web.lua 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  -- Installation root; models/, assets/, lib/ and cache/ are resolved under it.
  local ROOT = '/home/ubuntu/waifu2x'

  -- Must be set before turbo is required below.
  _G.TURBO_SSL = true -- Enable SSL

  local turbo = require('turbo')
  local uuid = require('uuid')
  local ffi = require('ffi')
  local md5 = require('md5')
  require('torch')
  require('cudnn')
  require('pl')

  torch.setdefaulttensortype('torch.FloatTensor')
  torch.setnumthreads(4)

  -- Make the project's own modules in ROOT/lib requirable; this has to happen
  -- before the requires that follow.
  package.path = string.format("%s;%s", package.path, path.join(ROOT, 'lib', '?.lua'))
  require('LeakyReLU')
  local iproc = require('iproc')
  local reconstract = require('reconstract')
  local image_loader = require('image_loader')
  --- Load a serialized network from ROOT/models (torch "ascii" format).
  -- @param filename file name inside ROOT/models
  -- @return the deserialized model
  local function load_model(filename)
    return torch.load(path.join(ROOT, "models", filename), "ascii")
  end
  local noise1_model = load_model("noise1_model.t7")      -- used for noise level 1
  local noise2_model = load_model("noise2_model.t7")      -- used for noise level 2
  local scale20_model = load_model("scale2.0x_model.t7")  -- used for 2x upscaling

  -- Per-stage result cache: PNG files stored under ROOT/cache.
  local USE_CACHE = true
  local CACHE_DIR = path.join(ROOT, "cache")

  -- Input pixel-count limits: the first applies when scale == 0,
  -- the second whenever scaling is requested.
  local MAX_NOISE_IMAGE = 2560 * 2560
  local MAX_SCALE_IMAGE = 1280 * 1280

  -- Options for downloading the image given by the "url" argument.
  local CURL_OPTIONS = {
    request_timeout = 10,
    connect_timeout = 5,
    allow_redirects = true,
    max_redirects = 1,
  }
  -- 2 MiB: download buffer cap, also used as the server's max request body size.
  local CURL_MAX_SIZE = 2 * 1024 * 1024

  local BLOCK_OFFSET = 7 -- see srcnn.lua
  --- Check that an image is within the pixel budget of the requested operation.
  -- @param x image tensor; x:size(2) * x:size(3) is taken as its pixel count
  -- @param scale requested scale mode; exactly 0 selects the (larger)
  --   no-scaling limit, any other value (including nil) the scaling limit
  -- @return true when the pixel count does not exceed the limit
  local function valid_size(x, scale)
    local limit = (scale == 0) and MAX_NOISE_IMAGE or MAX_SCALE_IMAGE
    local pixels = x:size(2) * x:size(3)
    return pixels <= limit
  end
  --- Obtain and decode the image supplied with an API request.
  -- A non-empty "file" argument (raw image bytes) wins; otherwise a non-empty
  -- "url" argument is downloaded and accepted only when the response is
  -- HTTP 200 with a Content-Type containing "image".
  -- @param req the turbo request handler
  -- @return decoded image (nil if nothing usable was supplied or decoding
  --   failed), and the raw bytes it came from (nil if none were obtained)
  local function get_image(req)
    local file = req:get_argument("file", "")
    local url = req:get_argument("url", "")

    if file and file:len() > 0 then
      return image_loader.decode_float(file), file
    end
    if not (url and url:len() > 0) then
      return nil, nil
    end

    -- NOTE(review): certificate verification is disabled and the URL is
    -- user-controlled; consider restricting which hosts may be fetched.
    local client = turbo.async.HTTPClient({verify_ca=false}, nil, CURL_MAX_SIZE)
    -- Suspends this handler's coroutine until turbo delivers the response.
    local res = coroutine.yield(client:fetch(url, CURL_OPTIONS))
    if res.code ~= 200 then
      return nil, nil
    end

    local content_type = res.headers:get("Content-Type", true)
    if type(content_type) == "table" then
      -- repeated header: use the first value
      content_type = content_type[1]
    end
    if not (content_type and content_type:find("image")) then
      return nil, nil
    end

    local body = res.body
    return image_loader.decode_float(body), body
  end
  --- Denoise an image with the level-1 noise model.
  -- @param x input image
  -- @return whatever reconstract produces for noise1_model on x
  local function apply_denoise1(x)
    local model = noise1_model
    local result = reconstract(model, x, BLOCK_OFFSET)
    return result
  end
  --- Denoise an image with the level-2 noise model.
  -- @param x input image
  -- @return whatever reconstract produces for noise2_model on x
  local function apply_denoise2(x)
    local model = noise2_model
    local result = reconstract(model, x, BLOCK_OFFSET)
    return result
  end
  --- Upscale an image 2x.
  -- The image is first resized to twice its size(3) x size(2) dimensions with
  -- iproc.scale, then passed through the 2x model.
  -- @param x input image
  -- @return the reconstructed 2x image
  local function apply_scale2x(x)
    local width, height = x:size(3), x:size(2)
    local enlarged = iproc.scale(x, width * 2.0, height * 2.0)
    return reconstract(scale20_model, enlarged, BLOCK_OFFSET)
  end
  --- Memoize an image transform in a file.
  -- If `cache` already exists it is loaded and returned and `func` is not
  -- called; otherwise `func(x)` is computed, saved to `cache`, and returned.
  -- @param cache path of the cache file
  -- @param x input passed to func on a cache miss
  -- @param func transform to apply on a cache miss
  -- @return the cached or freshly computed image
  local function cache_do(cache, x, func)
    if not path.exists(cache) then
      local result = func(x)
      image.save(cache, result)
      return result
    end
    return image.load(cache)
  end
  --- Report whether the client of a request handler is no longer reachable.
  -- @param handler turbo request handler
  -- @return true when the request, its connection or its stream is missing,
  --   or when the stream reports itself closed; false otherwise
  local function client_disconnected(handler)
    local request = handler.request
    local connection = request and request.connection
    local stream = connection and connection.stream
    if not stream then
      return true
    end
    return not not stream:closed()
  end
  --- Send the 400 "client disconnected" response.
  -- @param handler turbo request handler
  local function reject_disconnected(handler)
    handler:set_status(400)
    handler:write("client disconnected")
  end

  --- Apply the requested filters, reusing per-stage results from CACHE_DIR.
  -- Cache files are named by the MD5 hex digest of the raw source bytes.
  -- @param x decoded input image
  -- @param src raw bytes x was decoded from
  -- @param noise 1 or 2 selects a denoise model; other values skip denoising
  -- @param scale 2 = 2x, 1 = 1.6x; other values skip scaling
  -- @return the processed image
  local function convert_cached(x, src, noise, scale)
    local hash = md5.sumhexa(src)
    local function cache_file(suffix)
      return path.join(CACHE_DIR, hash .. suffix)
    end
    if noise == 1 then
      x = cache_do(cache_file("_noise1.png"), x, apply_denoise1)
    elseif noise == 2 then
      x = cache_do(cache_file("_noise2.png"), x, apply_denoise2)
    end
    if scale == 1 or scale == 2 then
      -- the 2x result is cached separately per denoise level
      local suffix = "_scale.png"
      if noise == 1 then
        suffix = "_noise1_scale.png"
      elseif noise == 2 then
        suffix = "_noise2_scale.png"
      end
      x = cache_do(cache_file(suffix), x, apply_scale2x)
      if scale == 1 then
        -- shrink the (cached) 2x image down to 1.6x of the input size
        x = iproc.scale(x,
                        math.floor(x:size(3) * (1.6 / 2.0) + 0.5),
                        math.floor(x:size(2) * (1.6 / 2.0) + 0.5),
                        "Jinc")
      end
    end
    return x
  end

  --- Apply the requested filters without touching the cache.
  -- @param x decoded input image
  -- @param noise 1 or 2 selects a denoise model; other values skip denoising
  -- @param scale 2 = 2x, 1 = 1.6x; other values skip scaling
  -- @return the processed image
  local function convert_uncached(x, noise, scale)
    if noise == 1 then
      x = apply_denoise1(x)
    elseif noise == 2 then
      x = apply_denoise2(x)
    end
    if scale == 1 then
      -- target size is taken from the image before the 2x pass
      local w16 = math.floor(x:size(3) * 1.6 + 0.5)
      local h16 = math.floor(x:size(2) * 1.6 + 0.5)
      x = apply_scale2x(x)
      x = iproc.scale(x, w16, h16, "Jinc")
    elseif scale == 2 then
      x = apply_scale2x(x)
    end
    return x
  end

  --- Write an image as a PNG response with a random file name.
  -- @param handler turbo request handler
  -- @param x image to encode
  local function respond_png(handler, x)
    local name = uuid() .. ".png"
    local blob, len = image_loader.encode_png(x)
    handler:set_header("Content-Disposition", string.format('filename="%s"', name))
    handler:set_header("Content-Type", "image/png")
    handler:set_header("Content-Length", string.format("%d", len))
    handler:write(ffi.string(blob, len))
  end

  local APIHandler = class("APIHandler", turbo.web.RequestHandler)

  --- POST /api: convert the image given by the "file" or "url" argument.
  -- Other arguments: "noise" (1 or 2) and "scale" (1 = 1.6x, 2 = 2x), both
  -- defaulting to "0". Responds with a PNG, or with status 400 and an
  -- error message.
  function APIHandler:post()
    if client_disconnected(self) then
      reject_disconnected(self)
      return
    end
    local x, src = get_image(self)
    -- For the "url" argument get_image suspends this handler while the image
    -- downloads (up to the fetch timeout). The client may have gone away in
    -- the meantime, so re-check before running the models.
    if client_disconnected(self) then
      reject_disconnected(self)
      return
    end
    local scale = tonumber(self:get_argument("scale", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))
    if not x then
      self:set_status(400)
      self:write("ERROR: unsupported image format.")
    elseif not valid_size(x, scale) then
      self:set_status(400)
      self:write("ERROR: image size exceeds maximum allowable size.")
    else
      if noise ~= 0 or scale ~= 0 then
        if USE_CACHE then
          x = convert_cached(x, src, noise, scale)
        else
          x = convert_uncached(x, noise, scale)
        end
      end
      respond_png(self, x)
    end
    collectgarbage()
  end
  local FormHandler = class("FormHandler", turbo.web.RequestHandler)
  local index_ja = file.read(path.join(ROOT, "assets/index.ja.html"))
  local index_en = file.read(path.join(ROOT, "assets/index.html"))

  --- Decide whether the first-listed language of an Accept-Language header
  -- is Japanese.
  -- Matches the primary subtag case-insensitively, so "ja", "ja-JP" and
  -- "JA-jp" all count. Quality values are ignored; only the first entry is
  -- considered.
  -- @param accept_language header value (string, a table of strings, or nil)
  -- @return true if the first entry's primary language is "ja"
  local function prefers_japanese(accept_language)
    if type(accept_language) == "table" then
      -- repeated header: use the first value, as get_image does for Content-Type
      accept_language = accept_language[1]
    end
    if type(accept_language) ~= "string" then
      return false
    end
    local first = accept_language:match("^%s*([^,;]*)")
    local primary = first and first:match("^(%a+)")
    return primary ~= nil and primary:lower() == "ja"
  end

  --- GET /: serve the Japanese or English form page based on Accept-Language.
  function FormHandler:get()
    if prefers_japanese(self.request.headers:get("Accept-Language")) then
      self:write(index_ja)
    else
      self:write(index_en)
    end
  end
  -- Route patterns are Lua patterns: a literal "." must be written "%.", and
  -- "$" anchors the end so paths such as "/index_html" or "/index.html.bak"
  -- do not match.
  local app = turbo.web.Application:new(
    {
      {"^/$", FormHandler},
      {"^/index%.html$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")},
      {"^/index%.ja%.html$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")},
      {"^/api$", APIHandler},
    }
  )
  -- Request bodies (uploads) share the CURL_MAX_SIZE cap used for URL downloads.
  app:listen(8812, "0.0.0.0", {max_body_size = CURL_MAX_SIZE})
  turbo.ioloop.instance():start()