web.lua 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  -- Bootstrap: load Penlight, locate this script, make ROOT/lib requireable,
  -- then load the project and third-party modules used below.
  -- NOTE(review): Penlight presumably supplies the globals `path`, `file`,
  -- `utils` and `class` used throughout this script — confirm.
  require 'pl'
  -- Path of this script: debug.getinfo(2) inside the anonymous function
  -- describes its caller (the main chunk); the leading "@" that marks a
  -- file-based chunk is stripped. Only gsub's first result is kept.
  local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  local ROOT = path.dirname(__FILE__)
  -- Prepend ROOT/lib to the module search path. This must precede the
  -- requires below (presumably w2nn, iproc, reconstruct and image_loader
  -- live in lib/ — TODO confirm).
  package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
  -- Set before turbo is required; presumably enables turbo's SSL support so
  -- that https:// image URLs can be fetched (NOTE(review): confirm in turbo docs).
  _G.TURBO_SSL = true
  require 'w2nn'
  local uuid = require 'uuid'
  local ffi = require 'ffi'
  local md5 = require 'md5'
  local iproc = require 'iproc'
  local reconstruct = require 'reconstruct'
  local image_loader = require 'image_loader'
  -- Note: turbo and xlua have different implementations of string:split(),
  -- so the two definitions conflict. This script relies on turbo's version;
  -- turbo is required after the other modules (presumably so that its
  -- definition is the one left in place).
  local turbo = require 'turbo'
  -- Command-line interface: listen port, CUDA device id and CPU thread count.
  local parser = torch.CmdLine()
  parser:text()
  parser:text("waifu2x-api")
  parser:text("Options:")
  for _, spec in ipairs({
    {"-port", 8812, 'listen port'},
    {"-gpu", 1, 'Device ID'},
    {"-thread", -1, 'number of CPU threads'},
  }) do
    parser:option(spec[1], spec[2], spec[3])
  end
  local opt = parser:parse(arg)

  -- Apply the parsed options to the torch runtime.
  cutorch.setDevice(opt.gpu)
  torch.setdefaulttensortype('torch.FloatTensor')
  -- A non-positive thread count keeps torch's default.
  if opt.thread > 0 then
    torch.setnumthreads(opt.thread)
  end
  -- `cudnn` is an optional global; tune it only when it is defined.
  if cudnn then
    cudnn.fastest = true
    cudnn.benchmark = false
  end
  -- Pretrained networks; every model file is stored in torch's "ascii" format.
  local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
  local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
  local function load_model(dir, filename)
    return torch.load(path.join(dir, filename), "ascii")
  end
  local art_noise1_model = load_model(ART_MODEL_DIR, "noise1_model.t7")
  local art_noise2_model = load_model(ART_MODEL_DIR, "noise2_model.t7")
  local art_scale2_model = load_model(ART_MODEL_DIR, "scale2.0x_model.t7")
  local photo_scale2_model = load_model(PHOTO_MODEL_DIR, "scale2.0x_model.t7")
  -- Holds downloaded images (url_<md5>) and conversion results (<prefix>.png).
  local CACHE_DIR = path.join(ROOT, "cache")
  -- Pixel budgets (height * width of the input) checked by valid_size():
  -- the larger one applies when no upscaling is requested (scale == 0).
  local MAX_NOISE_IMAGE = 2560 * 2560
  local MAX_SCALE_IMAGE = 1280 * 1280
  -- Options passed to turbo's HTTPClient:fetch() when downloading by URL.
  local CURL_OPTIONS = {
    connect_timeout = 10,
    request_timeout = 15,
    max_redirects = 2,
    allow_redirects = true,
  }
  -- 2 MiB: caps both downloaded images and uploaded request bodies.
  local CURL_MAX_SIZE = 2 * 1024 * 1024
  --- Check that image tensor `x` fits the pixel budget for this request.
  -- Multiplies x:size(2) by x:size(3) (presumably height and width of a
  -- channels-first tensor — TODO confirm). Pure denoising (scale == 0) gets
  -- MAX_NOISE_IMAGE; any other scale value gets the stricter MAX_SCALE_IMAGE.
  -- @return true when the image is small enough
  local function valid_size(x, scale)
    local budget
    if scale ~= 0 then
      budget = MAX_SCALE_IMAGE
    else
      budget = MAX_NOISE_IMAGE
    end
    local pixels = x:size(2) * x:size(3)
    return pixels <= budget
  end
  --- Write `blob` to `cache_file` only when decoding succeeded (first decoder
  -- result non-nil), then return every decoder result unchanged.
  -- A failed io.open (missing/unwritable cache dir) just skips caching.
  local function store_if_decoded(cache_file, blob, x, ...)
    if x then
      local fp = io.open(cache_file, "wb")
      if fp then
        fp:write(blob)
        fp:close()
      end
    end
    return x, ...
  end

  --- Load the image at `url`, using an on-disk cache keyed by md5(url).
  -- Must run inside a coroutine: on a cache miss it yields the turbo
  -- HTTPClient fetch and resumes with the response.
  -- Only 200 responses whose Content-Type mentions "image" are decoded, and
  -- the raw bytes are cached only if they decoded, so an undecodable body can
  -- no longer poison the cache for that URL.
  -- @param url user-supplied image URL
  -- @return the image_loader results on success, nil, nil, nil otherwise
  local function cache_url(url)
    local hash = md5.sumhexa(url)
    local cache_file = path.join(CACHE_DIR, "url_" .. hash)
    if path.exists(cache_file) then
      return image_loader.load_float(cache_file)
    end
    -- NOTE(review): TLS certificate verification is disabled and the URL is
    -- untrusted user input (possible SSRF against internal hosts); left as-is.
    local res = coroutine.yield(
      turbo.async.HTTPClient({verify_ca = false}, nil, CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
    )
    if res and res.code == 200 then
      local content_type = res.headers:get("Content-Type", true)
      if type(content_type) == "table" then
        content_type = content_type[1] -- repeated header: use the first
      end
      if content_type and content_type:find("image", 1, true) then
        local blob = res.body
        return store_if_decoded(cache_file, blob, image_loader.decode_float(blob))
      end
    end
    return nil, nil, nil
  end
  --- Decode the image attached to request `req`.
  -- A non-empty "file" argument (raw image bytes) takes precedence over a
  -- non-empty "url" argument, which is fetched through cache_url().
  -- @return the loader's results, or nil, nil, nil when neither is supplied
  local function get_image(req)
    local file = req:get_argument("file", "")
    local url = req:get_argument("url", "")
    local has_file = file and file:len() > 0
    if has_file then
      return image_loader.decode_float(file)
    end
    local has_url = url and url:len() > 0
    if has_url then
      return cache_url(url)
    end
    return nil, nil, nil
  end
  --- Run one network pass over image `x`, memoized on disk.
  -- options.prefix names the cache file CACHE_DIR/<prefix>.png; if it exists
  -- it is loaded and returned instead of recomputing.
  -- options.style "art" picks an art model by options.method ("scale",
  -- "noise1", anything else -> noise2); every other style runs the photo 2x
  -- upscaler regardless of method. The model is cleaned up after use and the
  -- result saved to the cache file.
  local function convert(x, options)
    local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
    if path.exists(cache_file) then
      return image.load(cache_file)
    end

    local model, upscale
    if options.style ~= "art" then
      model, upscale = photo_scale2_model, true
    elseif options.method == "scale" then
      model, upscale = art_scale2_model, true
    elseif options.method == "noise1" then
      model = art_noise1_model
    else
      model = art_noise2_model
    end

    if upscale then
      x = reconstruct.scale(model, 2.0, x)
    else
      x = reconstruct.image(model, x)
    end
    w2nn.cleanup_model(model)

    image.save(cache_file, x)
    return x
  end
  --- Report whether the client behind `handler` has gone away.
  -- True when the handler has no request, connection or stream, or when the
  -- stream's closed() is truthy; false only for a present, open stream.
  local function client_disconnected(handler)
    local request = handler.request
    local connection = request and request.connection
    local stream = connection and connection.stream
    if not stream then
      return true
    end
    return not not stream:closed()
  end
  local APIHandler = class("APIHandler", turbo.web.RequestHandler)

  --- POST /api: denoise and/or upscale one image and reply with a PNG.
  -- Form arguments:
  --   file, url    image source; a non-empty "file" wins (see get_image)
  --   noise        1 or 2 selects an art denoising model; anything else: off
  --                (ignored for the photo style, which has no denoise model)
  --   scale        2 upscales 2x; 1 upscales 2x then resizes to 1.6x; else off
  --   white_noise  1 adds white noise, only when noise or scale is non-zero
  --   style        "art" (default); any other value means "photo"
  -- Non-numeric noise/scale/white_noise values are treated as 0.
  -- On failure, replies with status 400 and a plain-text message.
  function APIHandler:post()
    if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
    end
    local x, alpha, blob = get_image(self)
    -- tonumber() yields nil for non-numeric input, and nil ~= 0 is true, so
    -- without the fallback a garbage value would select the stricter
    -- MAX_SCALE_IMAGE limit and enter the conversion branch for nothing.
    local scale = tonumber(self:get_argument("scale", "0")) or 0
    local noise = tonumber(self:get_argument("noise", "0")) or 0
    local white_noise = tonumber(self:get_argument("white_noise", "0")) or 0
    local style = self:get_argument("style", "art")
    if style ~= "art" then
      style = "photo" -- style must be art or photo
      -- Only a 2x upscaling model is loaded for photos and convert() runs it
      -- for every photo request regardless of `method`, so honouring a noise
      -- request here would silently add an extra 2x upscale. Ignore it.
      noise = 0
    end
    if x and valid_size(x, scale) then
      if noise ~= 0 or scale ~= 0 then
        -- Results are cached on disk keyed by the md5 of the input bytes, so
        -- each cache prefix encodes every step applied so far.
        local hash = md5.sumhexa(blob)
        if noise == 1 then
          x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash})
        elseif noise == 2 then
          x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash})
        end
        if scale == 1 or scale == 2 then
          if noise == 1 then
            x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash})
          elseif noise == 2 then
            x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash})
          else
            x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash})
          end
          if scale == 1 then
            -- The model always doubles; shrink its output to 1.6x of the input.
            x = iproc.scale_with_gamma22(x,
              math.floor(x:size(3) * (1.6 / 2.0) + 0.5),
              math.floor(x:size(2) * (1.6 / 2.0) + 0.5),
              "Jinc")
          end
        end
        if white_noise == 1 then
          x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0})
        end
      end
      local name = uuid() .. ".png"
      -- Encoded output; deliberately not named `blob` (that is the input).
      local png, len = image_loader.encode_png(x, alpha)
      self:set_header("Content-Disposition", string.format('filename="%s"', name))
      self:set_header("Content-Type", "image/png")
      self:set_header("Content-Length", string.format("%d", len))
      self:write(ffi.string(png, len))
    else
      if not x then
        self:set_status(400)
        self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
      else
        self:set_status(400)
        self:write("ERROR: image size exceeds maximum allowable size.")
      end
    end
    collectgarbage()
  end
  local FormHandler = class("FormHandler", turbo.web.RequestHandler)
  -- Localized copies of the upload form, read once at startup.
  local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
  local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
  local index_en = file.read(path.join(ROOT, "assets", "index.html"))

  --- Return the lowercased primary language subtag of the first entry of an
  -- Accept-Language value ("ja-JP,en;q=0.8" -> "ja", " RU" -> "ru"), or nil
  -- when the header is absent or its first entry does not start with letters
  -- (e.g. "*"). A repeated header (returned as a table) uses its first value.
  local function primary_language(header)
    if type(header) == "table" then
      header = header[1]
    end
    if type(header) ~= "string" then
      return nil
    end
    local tag = header:match("^%s*(%a+)")
    return tag and tag:lower()
  end

  --- GET /: serve the upload form in the visitor's preferred language.
  -- Only the first Accept-Language entry is considered (q-values ignored).
  -- Japanese and Russian have translations; everything else gets English.
  function FormHandler:get()
    local lang = primary_language(self.request.headers:get("Accept-Language"))
    if lang == "ja" then
      self:write(index_ja)
    elseif lang == "ru" then
      self:write(index_ru)
    else
      self:write(index_en)
    end
  end
  -- Turbo log output: keep success, warning and error messages; silence
  -- notice, debug and development messages.
  turbo.log.categories = {
    success = true,
    warning = true,
    error = true,
    notice = false,
    debug = false,
    development = false,
  }
  -- Route table. The route strings are Lua patterns (TODO confirm against the
  -- turbo.web docs), so literal dots are escaped as %. and every route is
  -- anchored at both ends like "^/$" and "^/api$". Unanchored, "^/style.css"
  -- would also match paths such as "/style.css.map" or "/styleXcss".
  local app = turbo.web.Application:new(
    {
      {"^/$", FormHandler},
      {"^/style%.css$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "style.css")},
      {"^/ui%.js$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "ui.js")},
      {"^/index%.html$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")},
      {"^/index%.ja%.html$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")},
      {"^/index%.ru%.html$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ru.html")},
      {"^/api$", APIHandler},
    }
  )
  -- Request bodies (uploads) share the CURL_MAX_SIZE cap used for URL fetches.
  app:listen(opt.port, "0.0.0.0", {max_body_size = CURL_MAX_SIZE})
  turbo.ioloop.instance():start()