-- web.lua (7.1 KB)
  -- Load Penlight before anything else: the `path` global used below (and
  -- `file`, `utils`, `class` later in this script) comes from it. Requiring it
  -- here means the script no longer depends on the launcher having preloaded it.
  require 'pl'
  -- Path of this script file: debug source names of file chunks start with
  -- "@", which is stripped. Level 2 is the main chunk calling the closure.
  local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  -- Let require() find modules in the lib/ directory next to this script.
  package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
  -- Set before `require 'turbo'` runs; presumably read by turbo at load time
  -- to enable SSL support (needed for https:// URL inputs) — confirm.
  _G.TURBO_SSL = true
  4. require 'pl'
  5. require 'w2nn'
  6. local turbo = require 'turbo'
  7. local uuid = require 'uuid'
  8. local ffi = require 'ffi'
  9. local md5 = require 'md5'
  10. local iproc = require 'iproc'
  11. local reconstruct = require 'reconstruct'
  12. local image_loader = require 'image_loader'
  -- Command-line interface.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-api")
  cmd:text("Options:")
  -- Each entry is {flag, default value, help text}.
  local option_specs = {
    {"-port", 8812, 'listen port'},
    {"-gpu", 1, 'Device ID'},
    {"-thread", -1, 'number of CPU threads'},
  }
  for _, spec in ipairs(option_specs) do
    cmd:option(spec[1], spec[2], spec[3])
  end
  local opt = cmd:parse(arg)

  -- Device and tensor configuration.
  cutorch.setDevice(opt.gpu)
  torch.setdefaulttensortype('torch.FloatTensor')
  -- A non-positive -thread value leaves torch's thread count untouched.
  if not (opt.thread <= 0) then
    torch.setnumthreads(opt.thread)
  end
  -- Only configure cudnn when that optional global is present.
  if cudnn ~= nil and cudnn ~= false then
    cudnn.fastest = true
    cudnn.benchmark = false
  end
  -- Pretrained models, stored in torch's "ascii" serialization format.
  -- MODEL_DIR is relative, so it resolves against the working directory.
  local MODEL_DIR = "./models/anime_style_art_rgb"
  local function load_model(file_name)
    return torch.load(path.join(MODEL_DIR, file_name), "ascii")
  end
  local noise1_model = load_model("noise1_model.t7")
  local noise2_model = load_model("noise2_model.t7")
  local scale20_model = load_model("scale2.0x_model.t7")

  -- Result caching for /api requests (see cache_do and cache_url).
  local USE_CACHE = true
  local CACHE_DIR = "./cache"

  -- Upper bounds on input height * width, checked by valid_size: the first
  -- applies when scale == 0, the second for any other scale value.
  local MAX_NOISE_IMAGE = 2560 * 2560
  local MAX_SCALE_IMAGE = 1280 * 1280

  -- Options for fetching URL inputs. CURL_MAX_SIZE is passed to the HTTP
  -- client and is also used as the server's max_body_size (see app:listen).
  local CURL_OPTIONS = {
    request_timeout = 15,
    connect_timeout = 10,
    allow_redirects = true,
    max_redirects = 2,
  }
  local CURL_MAX_SIZE = 2 * 1024 * 1024
  --- Check whether image tensor `x` is within the allowed pixel budget.
  -- scale == 0 uses MAX_NOISE_IMAGE; any other value (including nil) uses
  -- MAX_SCALE_IMAGE.
  -- @param x tensor whose sizes along dims 2 and 3 are multiplied
  -- @param scale requested scale mode
  -- @return boolean
  local function valid_size(x, scale)
    local limit = (scale == 0) and MAX_NOISE_IMAGE or MAX_SCALE_IMAGE
    local pixels = x:size(2) * x:size(3)
    return pixels <= limit
  end
  -- Wrap a denoising model as a one-argument function of an image tensor.
  local function make_denoiser(model)
    return function(x)
      return reconstruct.image(model, x)
    end
  end

  -- Denoise with the level-1 / level-2 models.
  local apply_denoise1 = make_denoiser(noise1_model)
  local apply_denoise2 = make_denoiser(noise2_model)

  -- Upscale with the 2.0x model.
  local function apply_scale2x(x)
    local factor = 2.0
    return reconstruct.scale(scale20_model, factor, x)
  end
  --- Fetch an image from `url`, keeping a raw copy under CACHE_DIR.
  -- The HTTP fetch is awaited with coroutine.yield, so this must be called
  -- from a coroutine driven by turbo's IOLoop (as request handlers are).
  -- If a cached copy named "url_<md5(url)>" exists, it is loaded instead.
  -- @param url string supplied by the client
  -- @return the values returned by image_loader.load_float / decode_float
  --   (callers unpack them as x, alpha, src), or nil, nil, nil when the fetch
  --   fails, the status is not 200, or the response is not labelled an image.
  local function cache_url(url)
    local hash = md5.sumhexa(url)
    local cache_file = path.join(CACHE_DIR, "url_" .. hash)
    if path.exists(cache_file) then
      return image_loader.load_float(cache_file)
    end
    -- NOTE(review): the URL is user-controlled and certificate verification
    -- is disabled (verify_ca=false); confirm this is acceptable.
    local res = coroutine.yield(
      turbo.async.HTTPClient({verify_ca = false}, nil, CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
    )
    if not res or res.code ~= 200 then
      return nil, nil, nil
    end
    local content_type = res.headers:get("Content-Type", true)
    if type(content_type) == "table" then
      -- turbo may return every value of a repeated header; use the first.
      content_type = content_type[1]
    end
    -- Plain substring search (no pattern magic) for "image".
    if not (content_type and content_type:find("image", 1, true)) then
      return nil, nil, nil
    end
    local blob = res.body
    -- Caching is best-effort: if the cache file cannot be opened (e.g. the
    -- cache directory is missing), still decode and return the download
    -- instead of failing the request.
    local fp = io.open(cache_file, "wb")
    if fp then
      fp:write(blob)
      fp:close()
    end
    return image_loader.decode_float(blob)
  end
  --- Memoize `func(x)` in the image file `cache`.
  -- On a miss, computes func(x), saves it with image.save and returns it.
  -- On a hit, returns image.load(cache) without calling func.
  local function cache_do(cache, x, func)
    if not path.exists(cache) then
      local result = func(x)
      image.save(cache, result)
      return result
    end
    return image.load(cache)
  end
  --- Extract the input image from an API request.
  -- A non-empty `file` argument takes priority over a non-empty `url`.
  -- @param req turbo request handler providing get_argument
  -- @return the values of image_loader.decode_float(file) or cache_url(url)
  --   (the caller unpacks them as x, alpha, src), or nil, nil, nil when
  --   neither argument is present.
  local function get_image(req)
    local file = req:get_argument("file", "")
    local url = req:get_argument("url", "")
    -- Type checks avoid calling string methods on unexpected argument values.
    if type(file) == "string" and #file > 0 then
      return image_loader.decode_float(file)
    elseif type(url) == "string" and #url > 0 then
      return cache_url(url)
    end
    return nil, nil, nil
  end
  --- Report whether the client behind `handler` is gone.
  -- True when the request, its connection or its stream is missing, or when
  -- the stream reports itself closed; false otherwise.
  local function client_disconnected(handler)
    local request = handler.request
    local connection = request and request.connection
    local stream = connection and connection.stream
    if not stream then
      return true
    end
    return stream:closed() and true or false
  end
  -- Round a non-negative pixel dimension to the nearest integer.
  local function round_dim(v)
    return math.floor(v + 0.5)
  end

  --- Apply the requested denoise / upscale stages to image tensor `x`.
  -- @param x image tensor; x:size(3) and x:size(2) are used as width and
  --   height for the 1.6x resize
  -- @param noise 1 or 2 selects apply_denoise1 / apply_denoise2; any other
  --   value skips denoising
  -- @param scale 2 = 2x upscale; 1 = 2x upscale then Jinc resize to 1.6x of
  --   the pre-upscale size; any other value skips upscaling
  -- @param hash optional; when non-nil each stage goes through cache_do with
  --   the file CACHE_DIR/<hash><suffix>, otherwise stages run uncached
  -- @return the processed tensor
  local function run_filters(x, noise, scale, hash)
    local function stage(suffix, input, func)
      if hash then
        return cache_do(path.join(CACHE_DIR, hash .. suffix), input, func)
      end
      return func(input)
    end

    if noise == 1 then
      x = stage("_noise1.png", x, apply_denoise1)
    elseif noise == 2 then
      x = stage("_noise2.png", x, apply_denoise2)
    end

    if scale == 1 or scale == 2 then
      -- The 1.6x target is taken from the pre-upscale size so the cached and
      -- uncached paths yield identical output dimensions.
      local w16 = round_dim(x:size(3) * 1.6)
      local h16 = round_dim(x:size(2) * 1.6)
      -- The upscale cache entry depends on which denoise stage preceded it.
      local suffix = "_scale.png"
      if noise == 1 then
        suffix = "_noise1_scale.png"
      elseif noise == 2 then
        suffix = "_noise2_scale.png"
      end
      x = stage(suffix, x, apply_scale2x)
      if scale == 1 then
        x = iproc.scale(x, w16, h16, "Jinc")
      end
    end
    return x
  end

  local APIHandler = class("APIHandler", turbo.web.RequestHandler)

  --- POST /api: convert an uploaded or remote image and reply with a PNG.
  -- Arguments: `file` or `url` (see get_image), `noise`, `scale`.
  -- Replies 400 with a plain-text message when the client is gone, the
  -- image cannot be decoded, or it exceeds the size limit (valid_size).
  function APIHandler:post()
    if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
    end
    local x, alpha, src = get_image(self)
    local scale = tonumber(self:get_argument("scale", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))

    if not x then
      self:set_status(400)
      self:write("ERROR: unsupported image format.")
    elseif not valid_size(x, scale) then
      self:set_status(400)
      self:write("ERROR: image size exceeds maximum allowable size.")
    else
      if noise ~= 0 or scale ~= 0 then
        -- Cache entries are keyed by the md5 of `src`, the third value
        -- returned by get_image.
        local hash = nil
        if USE_CACHE then
          hash = md5.sumhexa(src)
        end
        x = run_filters(x, noise, scale, hash)
      end
      local name = uuid() .. ".png"
      local blob, len = image_loader.encode_png(x, alpha)
      self:set_header("Content-Disposition", string.format('filename="%s"', name))
      self:set_header("Content-Type", "image/png")
      self:set_header("Content-Length", string.format("%d", len))
      self:write(ffi.string(blob, len))
    end
    -- Reclaim tensor garbage from this request promptly.
    collectgarbage()
  end
  local FormHandler = class("FormHandler", turbo.web.RequestHandler)
  local index_ja = file.read("./assets/index.ja.html")
  local index_ru = file.read("./assets/index.ru.html")
  local index_en = file.read("./assets/index.html")
  -- Localized index pages keyed by lower-case primary language subtag.
  -- Any other language, or a page that failed to load, falls back to index_en.
  local INDEX_BY_LANG = {ja = index_ja, ru = index_ru}

  --- Return the lower-case primary subtag of the first language listed in an
  -- Accept-Language value (e.g. "ja-JP,en;q=0.8" -> "ja"), or nil.
  -- Only list order is considered; q-values are ignored.
  local function preferred_language(header)
    if type(header) == "table" then
      -- A header lookup may yield a list of values (cache_url handles the
      -- same case for Content-Type); use the first one.
      header = header[1]
    end
    if type(header) ~= "string" then
      return nil
    end
    local tag = header:match("^%s*(%a+)")
    return tag and tag:lower()
  end

  --- GET /: serve the index page in the client's preferred language.
  function FormHandler:get()
    local lang = preferred_language(self.request.headers:get("Accept-Language"))
    -- Indexing with a nil key simply yields nil, i.e. the English fallback.
    self:write(INDEX_BY_LANG[lang] or index_en)
  end
  -- Emit only success, warning and error log messages.
  turbo.log.categories = {
    success = true,
    notice = false,
    warning = true,
    error = true,
    debug = false,
    development = false,
  }
  -- Route table. Patterns are Lua patterns: "." must be escaped as "%." and
  -- "$" anchors the end, so only the exact page paths reach the static
  -- handlers (previously "^/index.html" also matched e.g. "/index.html.bak").
  local app = turbo.web.Application:new(
    {
      {"^/$", FormHandler},
      {"^/index%.html$", turbo.web.StaticFileHandler, path.join("./assets", "index.html")},
      {"^/index%.ja%.html$", turbo.web.StaticFileHandler, path.join("./assets", "index.ja.html")},
      {"^/index%.ru%.html$", turbo.web.StaticFileHandler, path.join("./assets", "index.ru.html")},
      {"^/api$", APIHandler},
    }
  )
  -- Listen on all interfaces; request bodies are capped at CURL_MAX_SIZE.
  app:listen(opt.port, "0.0.0.0", {max_body_size = CURL_MAX_SIZE})
  turbo.ioloop.instance():start()