web.lua

require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
local ROOT = path.dirname(__FILE__)
package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
_G.TURBO_SSL = true
require 'w2nn'
local uuid = require 'uuid'
local ffi = require 'ffi'
local md5 = require 'md5'
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
-- Note: turbo and xlua provide different implementations of string:split(),
-- so the two definitions conflict. This script uses turbo's string:split().
local turbo = require 'turbo'
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x-api")
cmd:text("Options:")
cmd:option("-port", 8812, 'listen port')
cmd:option("-gpu", 1, 'Device ID')
cmd:option("-thread", -1, 'number of CPU threads')
local opt = cmd:parse(arg)
cutorch.setDevice(opt.gpu)
torch.setdefaulttensortype('torch.FloatTensor')
if opt.thread > 0 then
   torch.setnumthreads(opt.thread)
end
if cudnn then
   cudnn.fastest = true
   cudnn.benchmark = false
end
local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
local CLEANUP_MODEL = false -- enable this flag if your GPU has limited memory
local CACHE_DIR = path.join(ROOT, "cache")
local MAX_NOISE_IMAGE = 2560 * 2560
local MAX_SCALE_IMAGE = 1280 * 1280
local CURL_OPTIONS = {
   request_timeout = 15,
   connect_timeout = 10,
   allow_redirects = true,
   max_redirects = 2
}
local CURL_MAX_SIZE = 2 * 1024 * 1024
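-- Reject inputs above a pixel-count limit; noise-only requests (scale == 0)
-- allow a larger image than requests that also scale.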
local function valid_size(x, scale)
   if scale == 0 then
      return x:size(2) * x:size(3) <= MAX_NOISE_IMAGE
   else
      return x:size(2) * x:size(3) <= MAX_SCALE_IMAGE
   end
end
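-- Fetch an image from a URL; the raw response body is cached on disk,
-- keyed by the MD5 hash of the URL.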
local function cache_url(url)
   local hash = md5.sumhexa(url)
   local cache_file = path.join(CACHE_DIR, "url_" .. hash)
   if path.exists(cache_file) then
      return image_loader.load_float(cache_file)
   else
      local res = coroutine.yield(
         turbo.async.HTTPClient({verify_ca=false},
                                nil,
                                CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
      )
      if res.code == 200 then
         local content_type = res.headers:get("Content-Type", true)
         if type(content_type) == "table" then
            content_type = content_type[1]
         end
         if content_type and content_type:find("image") then
            local fp = io.open(cache_file, "wb")
            local blob = res.body
            fp:write(blob)
            fp:close()
            return image_loader.decode_float(blob)
         end
      end
   end
   return nil, nil, nil
end
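-- Resolve the input image from the request: an uploaded "file" blob takes
-- precedence over a remote "url".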
local function get_image(req)
   local file = req:get_argument("file", "")
   local url = req:get_argument("url", "")
   if file and file:len() > 0 then
      return image_loader.decode_float(file)
   elseif url and url:len() > 0 then
      return cache_url(url)
   end
   return nil, nil, nil
end
local function cleanup_model(model)
   if CLEANUP_MODEL then
      w2nn.cleanup_model(model) -- release GPU memory
   end
end
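-- Apply the requested model (noise reduction or 2x scaling) to x.
-- Results are cached as PNGs under CACHE_DIR, keyed by options.prefix.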
local function convert(x, options)
   local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
   if path.exists(cache_file) then
      return image.load(cache_file)
   else
      if options.style == "art" then
         if options.method == "scale" then
            x = reconstruct.scale(art_scale2_model, 2.0, x)
            cleanup_model(art_scale2_model)
         elseif options.method == "noise1" then
            x = reconstruct.image(art_noise1_model, x)
            cleanup_model(art_noise1_model)
         else -- options.method == "noise2"
            x = reconstruct.image(art_noise2_model, x)
            cleanup_model(art_noise2_model)
         end
      else -- photo
         if options.method == "scale" then
            x = reconstruct.scale(photo_scale2_model, 2.0, x)
            cleanup_model(photo_scale2_model)
         elseif options.method == "noise1" then
            x = reconstruct.image(photo_noise1_model, x)
            cleanup_model(photo_noise1_model)
         elseif options.method == "noise2" then
            x = reconstruct.image(photo_noise2_model, x)
            cleanup_model(photo_noise2_model)
         end
      end
      image.save(cache_file, x)
      return x
   end
end
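-- True if the client has already dropped the connection.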
local function client_disconnected(handler)
   return not (handler.request and
               handler.request.connection and
               handler.request.connection.stream and
               (not handler.request.connection.stream:closed()))
end
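-- POST /api: accepts "file" or "url" (input image), "noise" (0/1/2),
-- "scale" (0 = none, 1 = 1.6x, 2 = 2x), "style" ("art" or "photo") and
-- "white_noise" (0/1); responds with a PNG.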
local APIHandler = class("APIHandler", turbo.web.RequestHandler)
function APIHandler:post()
   if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
   end
   local x, alpha, blob = get_image(self)
   local scale = tonumber(self:get_argument("scale", "0"))
   local noise = tonumber(self:get_argument("noise", "0"))
   local white_noise = tonumber(self:get_argument("white_noise", "0"))
   local style = self:get_argument("style", "art")
   if style ~= "art" then
      style = "photo" -- style must be art or photo
   end
   if x and valid_size(x, scale) then
      if (noise ~= 0 or scale ~= 0) then
         local hash = md5.sumhexa(blob)
         if noise == 1 then
            x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash})
         elseif noise == 2 then
            x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash})
         end
         if scale == 1 or scale == 2 then
            if noise == 1 then
               x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash})
            elseif noise == 2 then
               x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash})
            else
               x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash})
            end
            if scale == 1 then
               x = iproc.scale_with_gamma22(x,
                                            math.floor(x:size(3) * (1.6 / 2.0) + 0.5),
                                            math.floor(x:size(2) * (1.6 / 2.0) + 0.5),
                                            "Jinc")
            end
         end
         if white_noise == 1 then
            x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0})
         end
      end
      local name = uuid() .. ".png"
      local blob = image_loader.encode_png(x, alpha)
      self:set_header("Content-Disposition", string.format('filename="%s"', name))
      self:set_header("Content-Type", "image/png")
      self:set_header("Content-Length", string.format("%d", #blob))
      self:write(blob)
   else
      if not x then
         self:set_status(400)
         self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
      else
         self:set_status(400)
         self:write("ERROR: image size exceeds maximum allowable size.")
      end
   end
   collectgarbage()
end
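-- GET /: serve the form page, choosing a localized index file from the
-- Accept-Language header (ja/ru, with English as the fallback).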
local FormHandler = class("FormHandler", turbo.web.RequestHandler)
local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
local index_en = file.read(path.join(ROOT, "assets", "index.html"))
function FormHandler:get()
   local lang = self.request.headers:get("Accept-Language")
   if lang then
      local langs = utils.split(lang, ",")
      for i = 1, #langs do
         langs[i] = utils.split(langs[i], ";")[1]
      end
      if langs[1] == "ja" then
         self:write(index_ja)
      elseif langs[1] == "ru" then
         self:write(index_ru)
      else
         self:write(index_en)
      end
   else
      self:write(index_en)
   end
end
turbo.log.categories = {
   ["success"] = true,
   ["notice"] = false,
   ["warning"] = true,
   ["error"] = true,
   ["debug"] = false,
   ["development"] = false
}
local app = turbo.web.Application:new(
   {
      {"^/$", FormHandler},
      {"^/style.css", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "style.css")},
      {"^/ui.js", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "ui.js")},
      {"^/index.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")},
      {"^/index.ja.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")},
      {"^/index.ru.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ru.html")},
      {"^/api$", APIHandler},
   }
)
app:listen(opt.port, "0.0.0.0", {max_body_size = CURL_MAX_SIZE})
turbo.ioloop.instance():start()
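-- Example request (illustrative sketch only; assumes the server is running
-- locally on the default port and that the handler receives the multipart
-- form fields named above):
--   curl -o out.png -F "file=@input.png" -F "style=art" -F "noise=1" -F "scale=2" \
--        http://localhost:8812/api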