-- web.lua (14 KB)
  1. require 'pl'
  2. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  3. local ROOT = path.dirname(__FILE__)
  4. package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
  5. _G.TURBO_SSL = true
  6. require 'w2nn'
  7. local uuid = require 'uuid'
  8. local ffi = require 'ffi'
  9. local md5 = require 'md5'
  10. local iproc = require 'iproc'
  11. local reconstruct = require 'reconstruct'
  12. local image_loader = require 'image_loader'
  13. local alpha_util = require 'alpha_util'
  14. local gm = require 'graphicsmagick'
  15. -- Note: turbo and xlua has different implementation of string:split().
  16. -- Therefore, string:split() has conflict issue.
  17. -- In this script, use turbo's string:split().
  18. local turbo = require 'turbo'
  -- Command-line interface: every option is registered in the order listed
  -- below, then parsed from the global `arg` table.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-api")
  cmd:text("Options:")
  do
    local option_specs = {
      {"-port", 8812, 'listen port'},
      {"-gpu", 1, 'Device ID'},
      {"-crop_size", 128, 'patch size per process'},
      {"-batch_size", 1, 'batch size'},
      {"-thread", -1, 'number of CPU threads'},
      {"-force_cudnn", 0, 'use cuDNN backend (0|1)'},
    }
    for i = 1, #option_specs do
      local spec = option_specs[i]
      cmd:option(spec[1], spec[2], spec[3])
    end
  end
  local opt = cmd:parse(arg)

  -- Runtime setup: select the GPU, default to float tensors, and only
  -- override the CPU thread count when a positive value was requested.
  cutorch.setDevice(opt.gpu)
  torch.setdefaulttensortype('torch.FloatTensor')
  if opt.thread > 0 then
    torch.setnumthreads(opt.thread)
  end

  -- cuDNN is optional; tune it only when the global is present.
  if cudnn then
    cudnn.fastest = true
    cudnn.benchmark = true
  end

  -- The option arrives as 0|1; the rest of the script uses it as a boolean.
  opt.force_cudnn = (opt.force_cudnn == 1)
  -- Model directories, relative to the directory containing this script.
  local ART_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "art")
  local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")

  -- Load one model file from `dir`, honouring the -force_cudnn option.
  local function load_model_from(dir, model_file)
    return w2nn.load_model(path.join(dir, model_file), opt.force_cudnn)
  end

  -- All eight models are loaded eagerly at startup, art first, then photo.
  local art_scale2_model = load_model_from(ART_MODEL_DIR, "scale2.0x_model.t7")
  local art_noise1_model = load_model_from(ART_MODEL_DIR, "noise1_model.t7")
  local art_noise2_model = load_model_from(ART_MODEL_DIR, "noise2_model.t7")
  local art_noise3_model = load_model_from(ART_MODEL_DIR, "noise3_model.t7")
  local photo_scale2_model = load_model_from(PHOTO_MODEL_DIR, "scale2.0x_model.t7")
  local photo_noise1_model = load_model_from(PHOTO_MODEL_DIR, "noise1_model.t7")
  local photo_noise2_model = load_model_from(PHOTO_MODEL_DIR, "noise2_model.t7")
  local photo_noise3_model = load_model_from(PHOTO_MODEL_DIR, "noise3_model.t7")

  -- Reclaim any temporaries created while loading.
  collectgarbage()
  -- Set to true on low-memory GPUs: cleanup_model() then clears each model's
  -- state after use to release GPU memory.
  local CLEANUP_MODEL = false

  -- Directory (next to this script) holding fetched and converted images.
  local CACHE_DIR = path.join(ROOT, "cache")

  -- Pixel-count limits checked by valid_size(): the first applies when no
  -- scaling is requested (scale == 0), the second in every other case.
  local MAX_NOISE_IMAGE = 2560 * 2560
  local MAX_SCALE_IMAGE = 1280 * 1280

  -- Options passed to the HTTP client when fetching a remote image.
  -- NOTE(review): timeouts are presumably in seconds -- confirm against Turbo.
  local CURL_OPTIONS = {
    request_timeout = 60,
    connect_timeout = 60,
    allow_redirects = true,
    max_redirects = 2
  }

  -- Size cap (3 MiB) given both to the HTTP client and to the listening
  -- server as max_body_size.
  local CURL_MAX_SIZE = 3 * 1024 * 1024

  -- When false, the API handler forces tta_level to 1 regardless of input.
  local TTA_SUPPORT = false
  --- Tell whether an image is small enough to be processed.
  -- The pixel count is the product of the tensor's 2nd and 3rd dimensions.
  -- @param x image tensor (anything exposing `:size(dim)`)
  -- @param scale requested scale mode; 0 selects the noise-only limit
  -- @return true when the pixel count is within the applicable limit
  local function valid_size(x, scale)
    local limit = MAX_SCALE_IMAGE
    if scale == 0 then
      limit = MAX_NOISE_IMAGE
    end
    local pixels = x:size(2) * x:size(3)
    return pixels <= limit
  end
  --- Fetch an image by URL, using the on-disk cache when possible.
  -- The cache key is the MD5 of the URL. On a cache miss the URL is fetched
  -- asynchronously (this function yields, so it must run inside a coroutine),
  -- and the body is cached and decoded only when the response is a 200 whose
  -- Content-Type mentions "image".
  -- @param url URL supplied by the client
  -- @return whatever image_loader returns for the image, or nil, nil when the
  --         fetch failed or the response was not an image
  local function cache_url(url)
    local hash = md5.sumhexa(url)
    local cache_file = path.join(CACHE_DIR, "url_" .. hash)
    if path.exists(cache_file) then
      return image_loader.load_float(cache_file)
    end

    -- NOTE(review): `url` is untrusted and certificate verification is
    -- disabled (verify_ca = false); this lets clients make the server fetch
    -- arbitrary hosts. Left as-is to preserve behaviour -- confirm intent.
    local res = coroutine.yield(
      turbo.async.HTTPClient({verify_ca = false},
                             nil,
                             CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
    )
    if not res or res.code ~= 200 then
      return nil, nil
    end

    local content_type = res.headers:get("Content-Type", true)
    if type(content_type) == "table" then
      content_type = content_type[1]
    end
    -- Media types are case-insensitive, so compare in lower case; the plain
    -- flag keeps the needle from being treated as a pattern.
    if not (content_type and content_type:lower():find("image", 1, true)) then
      return nil, nil
    end

    local blob = res.body
    -- Caching is best-effort: if the cache file cannot be opened, still
    -- serve the downloaded image instead of failing on a nil file handle.
    local fp = io.open(cache_file, "wb")
    if fp then
      fp:write(blob)
      fp:close()
    end
    return image_loader.decode_float(blob)
  end
  --- Extract the input image from a request.
  -- An uploaded "file" argument takes priority (only when exactly one was
  -- sent and it is non-empty); otherwise a non-empty "url" argument is
  -- fetched through cache_url().
  -- @param req request handler exposing get_arguments / get_argument
  -- @return image, meta, filename -- filename is the basename taken from the
  --         upload's content-disposition, or nil; all three are nil when the
  --         request carries no usable input
  local function get_image(req)
    local uploads = req:get_arguments("file")
    local url = req:get_argument("url", "")
    local file, filename

    if uploads and #uploads == 1 then
      local upload = uploads[1]
      file = upload[1]
      local disposition = upload["content-disposition"]
      if disposition and disposition["filename"] then
        filename = path.basename(disposition["filename"])
      end
    end

    if file and file:len() > 0 then
      local x, meta = image_loader.decode_float(file)
      return x, meta, filename
    end
    if url and url:len() > 0 then
      local x, meta = cache_url(url)
      return x, meta, filename
    end
    return nil, nil, nil
  end
  --- Release a model's GPU memory, but only when CLEANUP_MODEL is enabled.
  -- @param model object exposing `:clearState()`
  local function cleanup_model(model)
    if not CLEANUP_MODEL then
      return
    end
    model:clearState() -- release GPU memory
  end
  --- Run one conversion step (denoise or 2x upscale) with on-disk caching.
  -- @param x input image tensor
  -- @param meta image metadata; `meta.alpha` is the optional alpha channel
  -- @param options table with fields:
  --   method       "scale", "noise1", "noise2" or "noise3" (anything else
  --                leaves the image unchanged)
  --   style        "art" selects the art models, any other value the photo ones
  --   tta_level    passed through to the reconstruct *_tta functions
  --   prefix       cache file stem for the converted image
  --   alpha_prefix cache file stem for the upscaled alpha channel
  --   border       when truthy, alpha_util.make_border is applied first
  -- @return the converted image and a copy of `meta` whose `alpha` is the
  --         alpha channel matching that image
  local function convert(x, meta, options)
    local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
    local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
    local alpha = meta.alpha
    local alpha_orig = alpha

    -- Prefer a previously upscaled alpha channel, normalised to 1xHxW.
    if path.exists(alpha_cache_file) then
      alpha = image_loader.load_float(alpha_cache_file)
      if alpha:dim() == 2 then
        alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
      end
      if alpha:size(1) == 3 then
        alpha = image.rgb2y(alpha)
      end
    end

    if path.exists(cache_file) then
      x = image_loader.load_float(cache_file)
    else
      local models
      if options.style == "art" then
        models = {
          scale = art_scale2_model,
          noise1 = art_noise1_model,
          noise2 = art_noise2_model,
          noise3 = art_noise3_model,
        }
      else -- photo
        models = {
          scale = photo_scale2_model,
          noise1 = photo_noise1_model,
          noise2 = photo_noise2_model,
          noise3 = photo_noise3_model,
        }
      end

      if options.border then
        -- Always use the alpha that came with the input image. The cached
        -- alpha is only written after 2x upscaling, so it may not match the
        -- size of `x`; the photo path used it here by mistake.
        x = alpha_util.make_border(x, alpha_orig,
                                   reconstruct.offset_size(models.scale))
      end

      local model = models[options.method]
      if options.method == "scale" then
        x = reconstruct.scale_tta(model, options.tta_level, 2.0, x,
                                  opt.crop_size, opt.batch_size)
        -- Upscale (and cache) the alpha channel unless it already has the
        -- output size, e.g. because it was loaded from the cache above.
        if alpha and
           not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
          alpha = reconstruct.scale(model, 2.0, alpha,
                                    opt.crop_size, opt.batch_size)
          image_loader.save_png(alpha_cache_file, alpha)
        end
        cleanup_model(model)
      elseif model then
        x = reconstruct.image_tta(model, options.tta_level,
                                  x, opt.crop_size, opt.batch_size)
        cleanup_model(model)
      end
      image_loader.save_png(cache_file, x)
    end

    -- Never mutate the caller's meta table.
    meta = tablex.copy(meta)
    meta.alpha = alpha
    return x, meta
  end
  --- Report whether the client behind a request handler has gone away.
  -- The client counts as connected only when the handler has a request, a
  -- connection and a stream, and that stream is not closed.
  -- @param handler request handler
  -- @return boolean, true when the client is disconnected
  local function client_disconnected(handler)
    local request = handler.request
    local connection = request and request.connection
    local stream = connection and connection.stream
    if not stream then
      return true
    end
    return not not stream:closed()
  end
  --- Build the download name for a converted image.
  -- The original extension is replaced by ".png"; when `mode` is given,
  -- "_waifu2x_<mode>" is inserted before it.
  -- @param filename original file name
  -- @param mode conversion tag, or nil/false for none
  -- @return the output file name
  local function make_output_filename(filename, mode)
    local ext = path.extension(filename)
    local stem = filename:sub(1, #filename - #ext)
    local suffix = ".png"
    if mode then
      suffix = "_waifu2x_" .. mode .. ".png"
    end
    return stem .. suffix
  end
  local APIHandler = class("APIHandler", turbo.web.RequestHandler)

  --- Handle a conversion request.
  -- Request arguments:
  --   file / url  input image (see get_image)
  --   scale       0 = none, 1 = 1.6x, 2 = 2x
  --   noise       0 = none, 1..3 = noise reduction level
  --   tta_level   1, 2, 4 or 8; forced to 1 unless TTA_SUPPORT is enabled
  --   style       "art" (default); any other value means "photo"
  --   download    non-empty to send the result as an attachment
  -- Responds with the PNG, or status 400 and an error message.
  function APIHandler:post()
    if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
    end
    local x, meta, filename = get_image(self)
    -- Non-numeric values fall back to the defaults instead of becoming nil.
    local scale = tonumber(self:get_argument("scale", "0")) or 0
    local noise = tonumber(self:get_argument("noise", "0")) or 0
    -- This used to read the "noise" argument by mistake.
    local tta_level = tonumber(self:get_argument("tta_level", "1")) or 1
    local style = self:get_argument("style", "art")
    local download = (self:get_argument("download", "")):len()

    if not TTA_SUPPORT then
      tta_level = 1 -- disable TTA mode
    elseif not (tta_level == 1 or tta_level == 2 or
                tta_level == 4 or tta_level == 8) then
      tta_level = 1
    end
    if style ~= "art" then
      style = "photo" -- style must be art or photo
    end

    if not x then
      self:set_status(400)
      self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
    elseif not valid_size(x, scale) then
      self:set_status(400)
      self:write("ERROR: image size exceeds maximum allowable size.")
    else
      local prefix = nil
      if noise ~= 0 or scale ~= 0 then
        local hash = md5.sumhexa(meta.blob)
        local alpha_prefix = style .. "_" .. hash .. "_alpha"
        -- A border is added once, in the first step, and only when the
        -- image will be scaled and has an alpha channel.
        local border = false
        if scale ~= 0 and meta.alpha then
          border = true
        end

        local noise_tag = nil
        if noise == 1 or noise == 2 or noise == 3 then
          noise_tag = string.format("noise%d", noise)
          prefix = style .. "_" .. noise_tag .. "_tta_" .. tta_level .. "_"
          x = convert(x, meta, {method = noise_tag, style = style,
                                tta_level = tta_level,
                                prefix = prefix .. hash,
                                alpha_prefix = alpha_prefix, border = border})
          border = false
        end

        if scale == 1 or scale == 2 then
          if noise_tag then
            prefix = style .. "_" .. noise_tag .. "_scale_tta_" .. tta_level .. "_"
          else
            prefix = style .. "_scale_tta_" .. tta_level .. "_"
          end
          x, meta = convert(x, meta, {method = "scale", style = style,
                                      tta_level = tta_level,
                                      prefix = prefix .. hash,
                                      alpha_prefix = alpha_prefix,
                                      border = border})
          if scale == 1 then
            -- 1.6x is produced by shrinking the 2x result.
            x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
          end
        end
      end

      local name = nil
      if filename then
        -- The mode tag is the prefix without its trailing underscore.
        name = make_output_filename(filename, prefix and prefix:sub(1, -2) or nil)
      else
        name = uuid() .. ".png"
      end
      -- The file name comes from the client: neutralise quotes, backslashes
      -- and control characters before placing it in a quoted header value.
      name = name:gsub('[%c"\\]', "_")

      local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
                                           tablex.update({depth = 8, inplace = true}, meta))
      self:set_header("Content-Length", string.format("%d", #blob))
      if download > 0 then
        self:set_header("Content-Type", "application/octet-stream")
        self:set_header("Content-Disposition", string.format('attachment; filename="%s"', name))
      else
        self:set_header("Content-Type", "image/png")
        self:set_header("Content-Disposition", string.format('inline; filename="%s"', name))
      end
      self:write(blob)
    end
    collectgarbage()
  end
  local FormHandler = class("FormHandler", turbo.web.RequestHandler)

  -- The landing pages are read once at startup, one per supported language.
  local function read_asset(name)
    return file.read(path.join(ROOT, "assets", name))
  end
  local index_ja = read_asset("index.ja.html")
  local index_ru = read_asset("index.ru.html")
  local index_pt = read_asset("index.pt.html")
  local index_es = read_asset("index.es.html")
  local index_fr = read_asset("index.fr.html")
  local index_de = read_asset("index.de.html")
  local index_tr = read_asset("index.tr.html")
  local index_zh_cn = read_asset("index.zh-CN.html")
  local index_en = read_asset("index.html")

  --- Serve the landing page matching the first Accept-Language entry.
  -- Parameters after ";" in each entry are dropped; anything unrecognised,
  -- or a missing header, yields the English page.
  function FormHandler:get()
    local page = index_en
    local accept = self.request.headers:get("Accept-Language")
    if accept then
      local langs = utils.split(accept, ",")
      for i = 1, #langs do
        langs[i] = utils.split(langs[i], ";")[1]
      end
      local first = langs[1]
      if first == "ja" then
        page = index_ja
      elseif first == "ru" then
        page = index_ru
      elseif first == "pt" or first == "pt-BR" then
        page = index_pt
      elseif first == "es" or first == "es-ES" then
        page = index_es
      elseif first == "fr" then
        page = index_fr
      elseif first == "de" then
        page = index_de
      elseif first == "tr" then
        page = index_tr
      elseif first == "zh-CN" or first == "zh" then
        page = index_zh_cn
      end
    end
    self:write(page)
  end
  -- Log verbosity: keep success, warning and error; silence the rest.
  turbo.log.categories = {
    success = true,
    notice = false,
    warning = true,
    error = true,
    debug = false,
    development = false,
  }

  -- Routes: the form at "/", the conversion API at "/api", and any other
  -- single path segment served as a static file from the assets directory.
  local routes = {
    {"^/$", FormHandler},
    {"^/api$", APIHandler},
    {"^/([%a%d%.%-_]+)$", turbo.web.StaticFileHandler, path.join(ROOT, "assets/")},
  }
  local app = turbo.web.Application:new(routes)

  -- Listen on all interfaces; request bodies are capped at CURL_MAX_SIZE.
  app:listen(opt.port, "0.0.0.0", {max_body_size = CURL_MAX_SIZE})
  turbo.ioloop.instance():start()