web.lua 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. require 'pl'
  2. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  3. local ROOT = path.dirname(__FILE__)
  4. package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
  5. _G.TURBO_SSL = true
  6. require 'w2nn'
  7. local uuid = require 'uuid'
  8. local ffi = require 'ffi'
  9. local md5 = require 'md5'
  10. local iproc = require 'iproc'
  11. local reconstruct = require 'reconstruct'
  12. local image_loader = require 'image_loader'
  13. local alpha_util = require 'alpha_util'
  -- Note: turbo and xlua have different implementations of string:split(),
  -- so the two definitions conflict with each other.
  -- This script uses turbo's string:split().
  17. local turbo = require 'turbo'
  -- Command-line interface: listen port, device ID and CPU thread count.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text('waifu2x-api')
  cmd:text('Options:')
  cmd:option('-port', 8812, "listen port")
  cmd:option('-gpu', 1, "Device ID")
  cmd:option('-thread', -1, "number of CPU threads")
  local opt = cmd:parse(arg)

  -- Select the requested device and make float the default tensor type.
  cutorch.setDevice(opt.gpu)
  torch.setdefaulttensortype('torch.FloatTensor')

  -- A non-positive -thread value leaves the thread count untouched.
  if opt.thread > 0 then
     torch.setnumthreads(opt.thread)
  end

  -- cudnn is configured only when the global is present.
  if cudnn then
     cudnn.fastest = true
     cudnn.benchmark = false
  end
  -- Model directories, relative to the directory containing this script.
  local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
  local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")

  -- Load one ascii-serialized model file from the given model directory.
  local function load_model(dir, filename)
     return torch.load(path.join(dir, filename), "ascii")
  end

  -- The art-style models are loaded once, at startup.
  local art_noise1_model = load_model(ART_MODEL_DIR, "noise1_model.t7")
  local art_noise2_model = load_model(ART_MODEL_DIR, "noise2_model.t7")
  local art_scale2_model = load_model(ART_MODEL_DIR, "scale2.0x_model.t7")
  -- The photo-style models are currently disabled:
  --local photo_scale2_model = load_model(PHOTO_MODEL_DIR, "scale2.0x_model.t7")
  --local photo_noise1_model = load_model(PHOTO_MODEL_DIR, "noise1_model.t7")
  --local photo_noise2_model = load_model(PHOTO_MODEL_DIR, "noise2_model.t7")

  -- If you are using a low-memory GPU, you can set this flag to true.
  local CLEANUP_MODEL = false

  -- Directory holding cached downloads and conversion results.
  local CACHE_DIR = path.join(ROOT, "cache")

  -- Upper bounds on size(2) * size(3) of an input image, see valid_size().
  local MAX_NOISE_IMAGE = 2560 * 2560
  local MAX_SCALE_IMAGE = 1280 * 1280

  -- Options handed to the HTTP client when fetching an image by URL.
  local CURL_OPTIONS = {
     request_timeout = 15,
     connect_timeout = 10,
     allow_redirects = true,
     max_redirects = 2,
  }

  -- 2 MiB; passed to the HTTP client constructor and used as max_body_size.
  local CURL_MAX_SIZE = 2 * 1024 * 1024
  --- Check whether an image is small enough to be processed.
  -- @param x image tensor; size(2) * size(3) is compared against the limit
  -- @param scale requested scale mode; 0 selects MAX_NOISE_IMAGE, anything
  --   else selects MAX_SCALE_IMAGE
  -- @return true when the product does not exceed the selected limit
  local function valid_size(x, scale)
     local limit = MAX_SCALE_IMAGE
     if scale == 0 then
        limit = MAX_NOISE_IMAGE
     end
     return x:size(2) * x:size(3) <= limit
  end
  --- Fetch the image behind `url`, going through an on-disk cache.
  -- The cache file name is "url_" followed by the MD5 hex digest of the URL.
  -- The download is awaited with coroutine.yield, so this must be called from
  -- inside a coroutine.
  -- @param url URL of the image to fetch
  -- @return on a cache hit, whatever image_loader.load_float returns; after a
  --   successful download, the first three results of
  --   image_loader.decode_float (callers in this file destructure them as
  --   x, alpha, blob); otherwise nil, nil, nil
  local function cache_url(url)
     local hash = md5.sumhexa(url)
     local cache_file = path.join(CACHE_DIR, "url_" .. hash)
     if path.exists(cache_file) then
        return image_loader.load_float(cache_file)
     end

     -- NOTE(review): verify_ca=false appears to turn off certificate
     -- verification for HTTPS downloads; confirm that this is intended.
     local res = coroutine.yield(
        turbo.async.HTTPClient({verify_ca=false},
                               nil,
                               CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
     )
     if res and res.code == 200 and res.body then
        local content_type = res.headers:get("Content-Type", true)
        if type(content_type) == "table" then
           content_type = content_type[1]
        end
        if content_type and content_type:find("image") then
           local blob = res.body
           -- Decode before caching so that an undecodable download is never
           -- stored and served again from the cache.
           local x, alpha, decoded_blob = image_loader.decode_float(blob)
           if x then
              -- The cache is best effort: when the file cannot be opened
              -- the decoded image is still returned.
              local fp = io.open(cache_file, "wb")
              if fp then
                 fp:write(blob)
                 fp:close()
              end
           end
           return x, alpha, decoded_blob
        end
     end
     return nil, nil, nil
  end
  --- Obtain the input image for a request.
  -- An uploaded "file" argument takes precedence over a "url" argument.
  -- @param req request handler providing get_argument()
  -- @return the results of image_loader.decode_float (upload) or cache_url
  --   (URL); nil, nil, nil when neither argument is present
  local function get_image(req)
     local upload = req:get_argument("file", "")
     local remote = req:get_argument("url", "")
     if upload and upload:len() > 0 then
        return image_loader.decode_float(upload)
     end
     if remote and remote:len() > 0 then
        return cache_url(remote)
     end
     return nil, nil, nil
  end
  --- Release GPU memory held by `model`, but only when CLEANUP_MODEL is set.
  -- @param model model to pass to w2nn.cleanup_model
  local function cleanup_model(model)
     if not CLEANUP_MODEL then
        return
     end
     w2nn.cleanup_model(model) -- release GPU memory
  end
  --- Run one conversion step (noise reduction or 2x scaling) with caching.
  -- Results are cached as PNG files in CACHE_DIR under options.prefix; a
  -- scaled alpha channel is cached under options.alpha_prefix.
  -- @param x input image tensor
  -- @param alpha alpha channel tensor, or nil
  -- @param options table with the fields method ("scale", "noise1", anything
  --   else is treated as "noise2"), style, prefix, alpha_prefix and border
  -- @return x, alpha after the conversion (or as loaded from the cache)
  local function convert(x, alpha, options)
     local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
     local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
     -- Border generation below always uses the alpha passed by the caller,
     -- not the one loaded from the alpha cache.
     local alpha_orig = alpha

     if path.exists(alpha_cache_file) then
        alpha = image_loader.load_float(alpha_cache_file)
        -- Bring the cached alpha into a single-channel, 3-dimensional form.
        if alpha:dim() == 2 then
           alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
        end
        if alpha:size(1) == 3 then
           alpha = image.rgb2y(alpha)
        end
     end

     -- Cache hit: nothing to compute.
     if path.exists(cache_file) then
        local cached = image_loader.load_float(cache_file)
        return cached, alpha
     end

     if options.style == "art" then
        if options.border then
           x = alpha_util.make_border(x, alpha_orig,
                                      reconstruct.offset_size(art_scale2_model))
        end
        local method = options.method
        if method == "scale" then
           x = reconstruct.scale(art_scale2_model, 2.0, x)
           -- Scale and cache the alpha channel when its size differs from x.
           if alpha and not (alpha:size(2) == x:size(2) and
                             alpha:size(3) == x:size(3)) then
              alpha = reconstruct.scale(art_scale2_model, 2.0, alpha)
              image_loader.save_png(alpha_cache_file, alpha)
           end
           cleanup_model(art_scale2_model)
        else
           -- Every method other than "scale" and "noise1" uses the noise2 model.
           local model = art_noise2_model
           if method == "noise1" then
              model = art_noise1_model
           end
           x = reconstruct.image(model, x)
           cleanup_model(model)
        end
     end
     -- The photo style is disabled; for any style other than "art", x is
     -- written to the cache without being processed. Disabled code:
     --[[photo
     if options.border then
        x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
     end
     if options.method == "scale" then
        x = reconstruct.scale(photo_scale2_model, 2.0, x)
        if alpha then
           if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
              alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha)
              image_loader.save_png(alpha_cache_file, alpha)
           end
        end
        cleanup_model(photo_scale2_model)
     elseif options.method == "noise1" then
        x = reconstruct.image(photo_noise1_model, x)
        cleanup_model(photo_noise1_model)
     elseif options.method == "noise2" then
        x = reconstruct.image(photo_noise2_model, x)
        cleanup_model(photo_noise2_model)
     end
     --]]

     image_loader.save_png(cache_file, x)
     return x, alpha
  end
  --- Tell whether the client behind a request handler is gone.
  -- @param handler request handler; handler.request.connection.stream is
  --   inspected when every link of that chain is present
  -- @return true when any link of the chain is missing or the stream reports
  --   closed(), false otherwise
  local function client_disconnected(handler)
     local request = handler.request
     if not request then
        return true
     end
     local connection = request.connection
     if not connection then
        return true
     end
     local stream = connection.stream
     if not stream then
        return true
     end
     -- Double negation turns the result of closed() into a real boolean.
     return not not stream:closed()
  end
  local APIHandler = class("APIHandler", turbo.web.RequestHandler)

  --- POST handler of the conversion API.
  -- Request arguments:
  --   file / url  image source (see get_image)
  --   scale       0 = none, 1 = 1.6x, 2 = 2x (default 0)
  --   noise       0 = none, 1 or 2 = noise reduction model (default 0)
  --   style       "art"; every other value is treated as "photo"
  --   download    any non-empty value makes the response a download
  -- Responds with the resulting PNG, or with status 400 and an error message
  -- when the client is gone, no image could be obtained or the image is too
  -- large.
  function APIHandler:post()
     if client_disconnected(self) then
        self:set_status(400)
        self:write("client disconnected")
        return
     end
     local x, alpha, blob = get_image(self)
     -- tonumber() returns nil for non-numeric input; treat that as 0 so the
     -- comparisons below never see nil.
     local scale = tonumber(self:get_argument("scale", "0")) or 0
     local noise = tonumber(self:get_argument("noise", "0")) or 0
     local style = self:get_argument("style", "art")
     local download = (self:get_argument("download", "")):len()
     if style ~= "art" then
        style = "photo" -- style must be art or photo
     end
     if x and valid_size(x, scale) then
        if (noise ~= 0 or scale ~= 0) then
           -- Cache keys are derived from the MD5 of the encoded input image.
           local hash = md5.sumhexa(blob)
           local alpha_prefix = style .. "_" .. hash .. "_alpha"
           -- A border is only requested when scaling an image with alpha.
           local border = false
           if scale ~= 0 and alpha then
              border = true
           end
           if noise == 1 then
              x = convert(x, alpha, {method = "noise1", style = style,
                                     prefix = style .. "_noise1_" .. hash,
                                     alpha_prefix = alpha_prefix, border = border})
              -- The noise step has handled the border; do not ask for it again.
              border = false
           elseif noise == 2 then
              x = convert(x, alpha, {method = "noise2", style = style,
                                     prefix = style .. "_noise2_" .. hash,
                                     alpha_prefix = alpha_prefix, border = border})
              border = false
           end
           if scale == 1 or scale == 2 then
              local prefix
              if noise == 1 then
                 prefix = style .. "_noise1_scale_" .. hash
              elseif noise == 2 then
                 prefix = style .. "_noise2_scale_" .. hash
              else
                 prefix = style .. "_scale_" .. hash
              end
              x, alpha = convert(x, alpha, {method = "scale", style = style,
                                            prefix = prefix,
                                            alpha_prefix = alpha_prefix,
                                            border = border})
              if scale == 1 then
                 -- 1.6x output: shrink the 2x result by a factor of 1.6 / 2.0.
                 x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
              end
           end
        end
        local name = uuid() .. ".png"
        local png = image_loader.encode_png(alpha_util.composite(x, alpha))
        self:set_header("Content-Length", string.format("%d", #png))
        -- Content-Disposition needs a disposition type in front of the
        -- filename parameter.
        if download > 0 then
           self:set_header("Content-Type", "application/octet-stream")
           self:set_header("Content-Disposition",
                           string.format('attachment; filename="%s"', name))
        else
           self:set_header("Content-Type", "image/png")
           self:set_header("Content-Disposition",
                           string.format('inline; filename="%s"', name))
        end
        self:write(png)
     else
        self:set_status(400)
        if not x then
           self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
        else
           self:write("ERROR: image size exceeds maximum allowable size.")
        end
     end
     collectgarbage()
  end
  local FormHandler = class("FormHandler", turbo.web.RequestHandler)
  -- The localized index pages are read once, at startup.
  local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
  local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
  local index_en = file.read(path.join(ROOT, "assets", "index.html"))

  -- Return the lower-cased primary subtag of the first entry of an
  -- Accept-Language header value ("ja-JP,ja;q=0.9" -> "ja"), or nil when the
  -- value is missing or does not start with letters (e.g. "*").
  local function primary_language(header)
     if type(header) == "table" then
        -- Several header values: only the first one is considered.
        header = header[1]
     end
     if type(header) ~= "string" then
        return nil
     end
     local tag = header:match("^%s*(%a+)")
     if tag then
        return tag:lower()
     end
     return nil
  end

  --- Serve the index page in the client's first preferred language.
  -- Japanese and Russian are matched on the primary language subtag, so
  -- values such as "ja-JP" or "RU" also select the localized page; everything
  -- else, including a missing header, gets the English page.
  function FormHandler:get()
     local lang = primary_language(self.request.headers:get("Accept-Language"))
     if lang == "ja" then
        self:write(index_ja)
     elseif lang == "ru" then
        self:write(index_ru)
     else
        self:write(index_en)
     end
  end
  -- Log categories: success, warning and error are switched on;
  -- notice, debug and development are switched off.
  turbo.log.categories = {
     success = true,
     notice = false,
     warning = true,
     error = true,
     debug = false,
     development = false,
  }
  -- Routes. Patterns are Lua patterns, so the literal dots are escaped and
  -- each static route is anchored at both ends, like "^/$" and "^/api$".
  -- Without that, "^/style.css" would also match paths such as
  -- "/styleXcss/anything".
  local app = turbo.web.Application:new(
     {
        {"^/$", FormHandler},
        {"^/style%.css$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "style.css")},
        {"^/ui%.js$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "ui.js")},
        {"^/index%.html$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")},
        {"^/index%.ja%.html$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")},
        {"^/index%.ru%.html$", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ru.html")},
        {"^/api$", APIHandler},
     }
  )
  -- Listen on all interfaces; request bodies are limited to CURL_MAX_SIZE.
  app:listen(opt.port, "0.0.0.0", {max_body_size = CURL_MAX_SIZE})
  turbo.ioloop.instance():start()