web.lua 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  require("pl")
  -- Path of this script as reported by Lua, minus the leading "@" that marks file chunks.
  local __FILE__ = (debug.getinfo(1, "S").source:gsub("^@", ""))
  local ROOT = path.dirname(__FILE__)
  -- Put the bundled lib/ directory first on the module search path so the
  -- project modules required below resolve from it.
  package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
  -- Global flag for turbo, set before turbo is required further down
  -- (presumably enables turbo's SSL support -- see turbo docs).
  _G.TURBO_SSL = true
  require("w2nn")
  local uuid = require("uuid")
  local ffi = require("ffi")
  local md5 = require("md5")
  local iproc = require("iproc")
  local reconstruct = require("reconstruct")
  local image_loader = require("image_loader")
  local alpha_util = require("alpha_util")
  local gm = require("graphicsmagick")
  -- turbo and xlua ship different implementations of string:split(), so the
  -- two conflict. This script uses turbo's string:split().
  local turbo = require("turbo")
  -- Command-line interface of the API server.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-api")
  cmd:text("Options:")
  -- Each entry: flag, default value, help text (registered in this order).
  for _, spec in ipairs({
    {"-port", 8812, 'listen port'},
    {"-gpu", 1, 'Device ID'},
    {"-upsampling_filter", "Box", 'Upsampling filter (for dev)'},
    {"-crop_size", 128, 'patch size per process'},
    {"-thread", -1, 'number of CPU threads'},
  }) do
    cmd:option(spec[1], spec[2], spec[3])
  end
  local opt = cmd:parse(arg)

  -- Runtime configuration derived from the parsed options.
  cutorch.setDevice(opt.gpu)
  torch.setdefaulttensortype('torch.FloatTensor')
  if opt.thread > 0 then torch.setnumthreads(opt.thread) end
  if cudnn then
    cudnn.fastest = true
    cudnn.benchmark = false
  end
  local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
  local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")
  -- All models are stored in torch's "ascii" serialization format.
  local function load_model(dir, name)
    return torch.load(path.join(dir, name), "ascii")
  end
  local art_scale2_model = load_model(ART_MODEL_DIR, "scale2.0x_model.t7")
  local art_noise1_model = load_model(ART_MODEL_DIR, "noise1_model.t7")
  local art_noise2_model = load_model(ART_MODEL_DIR, "noise2_model.t7")
  local art_noise3_model = load_model(ART_MODEL_DIR, "noise3_model.t7")
  local photo_scale2_model = load_model(PHOTO_MODEL_DIR, "scale2.0x_model.t7")
  local photo_noise1_model = load_model(PHOTO_MODEL_DIR, "noise1_model.t7")
  local photo_noise2_model = load_model(PHOTO_MODEL_DIR, "noise2_model.t7")
  local photo_noise3_model = load_model(PHOTO_MODEL_DIR, "noise3_model.t7")
  -- When true, cleanup_model() calls clearState() on a model after each use;
  -- intended for low-memory GPUs.
  local CLEANUP_MODEL = false
  -- Directory holding cached downloads and conversion results.
  local CACHE_DIR = path.join(ROOT, "cache")
  -- Pixel-count limits (product of the two spatial dims) checked by valid_size():
  -- the first for noise-only requests, the second when scaling.
  local MAX_NOISE_IMAGE = 2560 * 2560
  local MAX_SCALE_IMAGE = 1280 * 1280
  -- Options handed to turbo's HTTPClient:fetch() for user-supplied URLs.
  local CURL_OPTIONS = {
    request_timeout = 60,
    connect_timeout = 60,
    allow_redirects = true,
    max_redirects = 2,
  }
  -- Byte cap used both for URL downloads and as the server's max request body size.
  local CURL_MAX_SIZE = 3 * 1024 * 1024
  --- Check whether an image is small enough to process.
  -- The pixel count is x:size(2) * x:size(3) (assumed to be height * width of a
  -- channel-first tensor -- confirm). Noise-only requests (scale == 0) use the
  -- larger MAX_NOISE_IMAGE limit; anything else uses MAX_SCALE_IMAGE.
  -- @return boolean
  local function valid_size(x, scale)
    local limit = MAX_SCALE_IMAGE
    if scale == 0 then
      limit = MAX_NOISE_IMAGE
    end
    return x:size(2) * x:size(3) <= limit
  end
  --- Download an image from a client-supplied URL, with an on-disk cache keyed by md5(url).
  -- Waits for the download with coroutine.yield, so it must run inside the request
  -- handler's coroutine (presumably turbo's -- confirm).
  -- NOTE(review): url is client-controlled and TLS verification is disabled
  -- (verify_ca=false), so any host reachable from the server -- including internal
  -- ones -- can be fetched through it. Consider restricting schemes/hosts.
  -- @param url string
  -- @return the results of image_loader.load_float (cache hit) or
  --   image_loader.decode_float (fresh download), which get_image treats as
  --   x, alpha, blob; or nil, nil, nil when the response is not 200, is not
  --   labelled as an image, or has no body.
  local function cache_url(url)
    local hash = md5.sumhexa(url)
    local cache_file = path.join(CACHE_DIR, "url_" .. hash)
    if path.exists(cache_file) then
      return image_loader.load_float(cache_file)
    end
    local res = coroutine.yield(
      turbo.async.HTTPClient({verify_ca=false}, nil, CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
    )
    if res.code ~= 200 then
      return nil, nil, nil
    end
    local content_type = res.headers:get("Content-Type", true)
    if type(content_type) == "table" then
      content_type = content_type[1] -- repeated header: use the first value
    end
    local blob = res.body
    if not (type(content_type) == "string"
            and content_type:find("image", 1, true)
            and blob) then
      return nil, nil, nil
    end
    -- Caching is best-effort: if the cache file cannot be opened or written
    -- (e.g. CACHE_DIR does not exist or is not writable), still serve the image.
    local fp = io.open(cache_file, "wb")
    if fp then
      local ok = fp:write(blob)
      fp:close()
      if not ok then
        -- Don't leave a truncated file behind to be served as a cache hit later.
        os.remove(cache_file)
      end
    end
    return image_loader.decode_float(blob)
  end
  --- Extract the input image from an API request.
  -- A single uploaded "file" field wins; otherwise the "url" argument is fetched
  -- through cache_url().
  -- @param req the request handler (provides get_arguments / get_argument)
  -- @return x, alpha, blob, filename. filename is set only for uploads whose
  --   content-disposition carries a filename (reduced to its basename). All four
  --   are nil when neither source yields data.
  local function get_image(req)
    local uploads = req:get_arguments("file")
    local url = req:get_argument("url", "")
    local data, filename
    if uploads and #uploads == 1 then
      local upload = uploads[1]
      data = upload[1]
      local disposition = upload["content-disposition"]
      if disposition and disposition["filename"] then
        filename = path.basename(disposition["filename"])
      end
    end
    if data and data:len() > 0 then
      local x, alpha, blob = image_loader.decode_float(data)
      return x, alpha, blob, filename
    end
    if url and url:len() > 0 then
      local x, alpha, blob = cache_url(url)
      return x, alpha, blob, filename
    end
    return nil, nil, nil, nil
  end
  --- Call model:clearState() after use, but only when CLEANUP_MODEL is enabled.
  -- The original intent, per its comment: release GPU memory on low-memory GPUs.
  local function cleanup_model(model)
    if not CLEANUP_MODEL then
      return
    end
    model:clearState() -- release GPU memory
  end
  -- Networks used by convert(), keyed by style and then by conversion method.
  local CONVERT_MODELS = {
    art = {
      scale = art_scale2_model,
      noise1 = art_noise1_model,
      noise2 = art_noise2_model,
      noise3 = art_noise3_model,
    },
    photo = {
      scale = photo_scale2_model,
      noise1 = photo_noise1_model,
      noise2 = photo_noise2_model,
      noise3 = photo_noise3_model,
    },
  }

  --- Run one conversion step on an image, caching the result as PNG in CACHE_DIR.
  -- @param x image tensor to convert (its dims 2/3 are compared with alpha's below;
  --   presumably channel-first -- confirm)
  -- @param alpha alpha channel of the source image, or nil
  -- @param options table with fields:
  --   method: "scale" (2x upscale), "noise1", "noise2" or "noise3"; any other
  --     value leaves x unchanged (the result is still cached)
  --   style: "art" selects the art models; any other value selects the photo models
  --   prefix: cache file name (without ".png") for the converted image
  --   alpha_prefix: cache file name (without ".png") for the 2x-scaled alpha
  --   border: when true, alpha_util.make_border is applied to x before converting
  -- @return x, alpha: the converted image and the alpha to composite with it
  --   (the 2x-scaled alpha after a "scale" step, or the cached scaled alpha if present)
  local function convert(x, alpha, options)
    local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
    local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
    -- The caller's alpha matches x's current size. The cached alpha loaded below is
    -- the 2x-scaled one, so the border must be built from alpha_orig for both styles.
    local alpha_orig = alpha
    if path.exists(alpha_cache_file) then
      alpha = image_loader.load_float(alpha_cache_file)
      if alpha:dim() == 2 then
        alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
      end
      if alpha:size(1) == 3 then
        alpha = image.rgb2y(alpha)
      end
    end
    if path.exists(cache_file) then
      x = image_loader.load_float(cache_file)
      return x, alpha
    end

    local models = CONVERT_MODELS.photo
    if options.style == "art" then
      models = CONVERT_MODELS.art
    end
    if options.border then
      x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(models.scale))
    end
    local model = models[options.method]
    if options.method == "scale" then
      x = reconstruct.scale(model, 2.0, x, opt.crop_size, opt.upsampling_filter)
      -- Scale alpha to match, unless it already does (e.g. loaded from the alpha cache).
      if alpha and not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
        alpha = reconstruct.scale(model, 2.0, alpha, opt.crop_size, opt.upsampling_filter)
        image_loader.save_png(alpha_cache_file, alpha)
      end
      cleanup_model(model)
    elseif model then
      x = reconstruct.image(model, x)
      cleanup_model(model)
    end
    image_loader.save_png(cache_file, x)
    return x, alpha
  end
  --- Report whether the client behind a request handler has gone away.
  -- Returns false only when handler.request.connection.stream exists and its
  -- closed() method returns a falsy value; otherwise true.
  local function client_disconnected(handler)
    local request = handler.request
    local connection = request and request.connection
    local stream = connection and connection.stream
    if stream and not stream:closed() then
      return false
    end
    return true
  end
  --- Build the download filename for a processed image.
  -- The original extension is replaced by ".png". When mode is given,
  -- "_waifu2x_<mode>" is inserted before it.
  -- The result goes inside a quoted Content-Disposition filename="..." value, and
  -- filename comes from the client's upload. Quotes, backslashes and control
  -- characters (including CR/LF) are therefore replaced with "_" so the header
  -- stays well-formed.
  -- @param filename string client-supplied file name
  -- @param mode string or nil, e.g. "art_noise1_scale"
  -- @return string
  local function make_output_filename(filename, mode)
    local ext = path.extension(filename)
    local base = filename:sub(1, #filename - #ext)
    local name
    if mode then
      name = base .. "_waifu2x_" .. mode .. ".png"
    else
      name = base .. ".png"
    end
    return (name:gsub('[%c"\\]', "_"))
  end
  -- Maps the numeric "noise" request argument to convert()'s method name.
  local NOISE_METHODS = {[1] = "noise1", [2] = "noise2", [3] = "noise3"}

  --- Apply the requested noise reduction, then upscaling, to an image.
  -- @param x decoded image
  -- @param alpha its alpha channel, or nil
  -- @param blob raw input bytes; md5(blob) is part of every cache file name
  -- @param style "art" or "photo"
  -- @param noise 1..3 selects a noise-reduction level; other values skip it
  -- @param scale 2 = 2x, 1 = 1.6x (2x, then shrunk); other values skip scaling
  -- @return x, alpha, prefix. prefix is the cache-name prefix (ending in "_") of
  --   the last step applied, or nil when no step ran.
  local function apply_filters(x, alpha, blob, style, noise, scale)
    local hash = md5.sumhexa(blob)
    local alpha_prefix = style .. "_" .. hash .. "_alpha"
    local prefix = nil
    -- The alpha border is requested only when scaling an image that has alpha,
    -- and only by whichever conversion step runs first.
    local border = false
    if scale ~= 0 and alpha then
      border = true
    end
    local noise_method = NOISE_METHODS[noise]
    if noise_method then
      prefix = style .. "_" .. noise_method .. "_"
      x = convert(x, alpha, {method = noise_method, style = style,
                             prefix = prefix .. hash,
                             alpha_prefix = alpha_prefix, border = border})
      border = false
    end
    if scale == 1 or scale == 2 then
      if noise_method then
        prefix = style .. "_" .. noise_method .. "_scale_"
      else
        prefix = style .. "_scale_"
      end
      x, alpha = convert(x, alpha, {method = "scale", style = style,
                                    prefix = prefix .. hash,
                                    alpha_prefix = alpha_prefix, border = border})
      if scale == 1 then
        -- 1.6x output: shrink the 2x result by 1.6 / 2.0.
        x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
      end
    end
    return x, alpha, prefix
  end

  --- Encode x (composited with alpha) as PNG and write it as the response body.
  -- @param download boolean: true sends it as an attachment, false inline
  local function write_png_response(handler, x, alpha, name, download)
    local png = image_loader.encode_png(alpha_util.composite(x, alpha), 8, true)
    handler:set_header("Content-Length", string.format("%d", #png))
    if download then
      handler:set_header("Content-Type", "application/octet-stream")
      handler:set_header("Content-Disposition", string.format('attachment; filename="%s"', name))
    else
      handler:set_header("Content-Type", "image/png")
      handler:set_header("Content-Disposition", string.format('inline; filename="%s"', name))
    end
    handler:write(png)
  end

  local APIHandler = class("APIHandler", turbo.web.RequestHandler)

  --- POST /api: denoise and/or upscale an uploaded or URL-referenced image.
  -- Request arguments: file or url (the image); scale (0, 1 or 2); noise (0-3);
  -- style ("art", anything else means "photo"); download (non-empty = attachment).
  -- Responds with a PNG, or status 400 and a plain-text error.
  function APIHandler:post()
    if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
    end
    local x, alpha, blob, filename = get_image(self)
    -- Non-numeric values are treated as 0 ("not requested") instead of nil.
    local scale = tonumber(self:get_argument("scale", "0")) or 0
    local noise = tonumber(self:get_argument("noise", "0")) or 0
    local style = self:get_argument("style", "art")
    local download = self:get_argument("download", ""):len() > 0
    if style ~= "art" then
      style = "photo" -- style must be art or photo
    end
    if not x then
      self:set_status(400)
      self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
    elseif not valid_size(x, scale) then
      self:set_status(400)
      self:write("ERROR: image size exceeds maximum allowable size.")
    else
      local prefix = nil
      if noise ~= 0 or scale ~= 0 then
        x, alpha, prefix = apply_filters(x, alpha, blob, style, noise, scale)
      end
      local name
      if not filename then
        name = uuid() .. ".png"
      elseif prefix then
        -- Drop the prefix's trailing "_", e.g. "art_noise1_" -> "art_noise1".
        name = make_output_filename(filename, prefix:sub(1, -2))
      else
        name = make_output_filename(filename, nil)
      end
      write_png_response(self, x, alpha, name, download)
    end
    collectgarbage()
  end
  local FormHandler = class("FormHandler", turbo.web.RequestHandler)
  local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
  local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
  local index_pt = file.read(path.join(ROOT, "assets", "index.pt.html"))
  local index_es = file.read(path.join(ROOT, "assets", "index.es.html"))
  local index_fr = file.read(path.join(ROOT, "assets", "index.fr.html"))
  local index_en = file.read(path.join(ROOT, "assets", "index.html"))
  -- Localized index pages keyed by lowercase primary language subtag.
  local INDEX_BY_LANGUAGE = {
    ja = index_ja,
    ru = index_ru,
    pt = index_pt,
    es = index_es,
    fr = index_fr,
  }

  --- GET /: serve the index page in the client's preferred language.
  -- Only the first Accept-Language entry is considered (q-values are ignored).
  -- Its primary subtag is matched case-insensitively, so "pt-BR", "ja-JP" and
  -- "es-MX" select the pt/ja/es pages. Anything unmatched gets the English page.
  function FormHandler:get()
    local header = self.request.headers:get("Accept-Language")
    if type(header) == "table" then
      header = header[1] -- repeated header: use the first value
    end
    local page = nil
    if type(header) == "string" then
      local lang = header:match("^%s*(%a+)")
      if lang then
        page = INDEX_BY_LANGUAGE[lang:lower()]
      end
    end
    self:write(page or index_en)
  end
  -- turbo log categories: success, warning and error messages are enabled;
  -- notice, debug and development output is disabled.
  turbo.log.categories = {
    success = true,
    notice = false,
    warning = true,
    error = true,
    debug = false,
    development = false,
  }
  -- URL routes: index form, API endpoint, then single-segment static files from assets/.
  -- NOTE(review): the static pattern's character class includes ".", so "/.." also
  -- matches; confirm StaticFileHandler refuses directory paths.
  local routes = {
    {"^/$", FormHandler},
    {"^/api$", APIHandler},
    {"^/([%a%d%.%-_]+)$", turbo.web.StaticFileHandler, path.join(ROOT, "assets/")},
  }
  local app = turbo.web.Application:new(routes)
  -- Listen on all interfaces. Request bodies are capped at the same size as URL downloads.
  app:listen(opt.port, "0.0.0.0", {max_body_size = CURL_MAX_SIZE})
  turbo.ioloop.instance():start()