web.lua 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. require 'pl'
  2. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  3. local ROOT = path.dirname(__FILE__)
  4. package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
  5. _G.TURBO_SSL = true
  6. require 'w2nn'
  7. local uuid = require 'uuid'
  8. local ffi = require 'ffi'
  9. local md5 = require 'md5'
  10. local iproc = require 'iproc'
  11. local reconstruct = require 'reconstruct'
  12. local image_loader = require 'image_loader'
  13. local alpha_util = require 'alpha_util'
  14. local gm = require 'graphicsmagick'
  15. -- Note: turbo and xlua provide different implementations of string:split(),
  16. -- so the two definitions conflict with each other.
  17. -- This script uses turbo's string:split().
  18. local turbo = require 'turbo'
  -- Command-line interface of the waifu2x web API server.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-api")
  cmd:text("Options:")
  do
    -- Each entry is {flag, default, help text}; they are registered in this
    -- order so the generated help output keeps the same layout.
    local option_specs = {
      {"-port", 8812, 'listen port'},
      {"-gpu", 1, 'Device ID'},
      {"-crop_size", 128, 'patch size per process'},
      {"-batch_size", 1, 'batch size'},
      {"-thread", -1, 'number of CPU threads'},
      {"-force_cudnn", 0, 'use cuDNN backend (0|1)'},
    }
    for i = 1, #option_specs do
      local spec = option_specs[i]
      cmd:option(spec[1], spec[2], spec[3])
    end
  end
  local opt = cmd:parse(arg)

  -- Torch / CUDA runtime configuration derived from the parsed options.
  cutorch.setDevice(opt.gpu)
  torch.setdefaulttensortype('torch.FloatTensor')
  -- A non-positive -thread value leaves the thread count untouched.
  if opt.thread > 0 then
    torch.setnumthreads(opt.thread)
  end
  -- Only tune cuDNN when the cudnn global is present.
  if cudnn then
    cudnn.fastest = true
    cudnn.benchmark = true
  end
  -- Turn the 0|1 command-line value into a real boolean.
  opt.force_cudnn = (opt.force_cudnn == 1)
  -- Directories holding the pre-trained upconv_7 models for each style.
  local ART_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "art")
  local PHOTO_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "photo")

  -- Model sets keyed by method name ("scale", "noiseN_scale", "noiseN").
  local art_model, photo_model
  do
    -- {method name, model file}; files are loaded in exactly this order.
    local model_files = {
      {"scale", "scale2.0x_model.t7"},
      {"noise0_scale", "noise0_scale2.0x_model.t7"},
      {"noise1_scale", "noise1_scale2.0x_model.t7"},
      {"noise2_scale", "noise2_scale2.0x_model.t7"},
      {"noise3_scale", "noise3_scale2.0x_model.t7"},
      {"noise0", "noise0_model.t7"},
      {"noise1", "noise1_model.t7"},
      {"noise2", "noise2_model.t7"},
      {"noise3", "noise3_model.t7"},
    }

    -- Load every model file found in model_files from `dir` into one table.
    local function load_model_set(dir)
      local set = {}
      for i = 1, #model_files do
        local method, file_name = model_files[i][1], model_files[i][2]
        set[method] = w2nn.load_model(path.join(dir, file_name), opt.force_cudnn)
      end
      return set
    end

    art_model = load_model_set(ART_MODEL_DIR)
    photo_model = load_model_set(PHOTO_MODEL_DIR)
  end
  -- Collect garbage left over from model loading.
  collectgarbage()
  -- When true, cleanup_model() calls model:clearState() after each use to
  -- release GPU memory; intended for GPUs with little memory.
  local CLEANUP_MODEL = false

  -- Directory where fetched URLs and converted images are cached.
  local CACHE_DIR = path.join(ROOT, "cache")

  -- Pixel budgets checked by valid_size(): the first applies when scale <= 0,
  -- the second when scale > 0. Both are divided by the TTA level.
  local MAX_NOISE_IMAGE = 3000 * 3000
  local MAX_SCALE_IMAGE = 1500 * 1500

  -- Options handed to the HTTP client's fetch() when downloading by URL.
  local CURL_OPTIONS = {
    request_timeout = 60,
    connect_timeout = 60,
    allow_redirects = true,
    max_redirects = 2,
  }

  -- Size limit passed to the HTTP client and used as the server's
  -- max_body_size.
  local CURL_MAX_SIZE = 5 * 1024 * 1024
  -- Check whether image `x` fits in the pixel budget for the request.
  --
  -- x:         object whose size(2) * size(3) gives the pixel count
  --            (presumably a CxHxW tensor -- confirm against image_loader).
  -- scale:     requested scale; <= 0 selects the noise-only budget.
  -- tta_level: TTA level; the budget is divided by it.
  -- Returns true when the pixel count is within the limit.
  local function valid_size(x, scale, tta_level)
    local budget = MAX_SCALE_IMAGE
    if scale <= 0 then
      budget = MAX_NOISE_IMAGE
    end
    -- Largest square (integer side) that fits in the per-TTA budget.
    local side = math.floor(math.pow(budget / tta_level, 0.5))
    local limit = math.pow(side, 2)
    local pixels = x:size(2) * x:size(3)
    return pixels <= limit
  end
  -- Pick the highest TTA level (8, 4 or 2) whose reduced pixel budget still
  -- accommodates image `x`; fall back to 1 when none does.
  --
  -- x:     object whose size(2) * size(3) gives the pixel count.
  -- scale: requested scale; <= 0 selects the noise-only budget.
  -- Returns 8, 4, 2 or 1.
  local function auto_tta_level(x, scale)
    local budget = (scale <= 0) and MAX_NOISE_IMAGE or MAX_SCALE_IMAGE
    local pixels = x:size(2) * x:size(3)
    -- Try the most expensive level first.
    local levels = {8, 4, 2}
    for i = 1, #levels do
      local level = levels[i]
      local limit = math.pow(math.floor(math.pow(budget / level, 0.5)), 2)
      if pixels <= limit then
        return level
      end
    end
    return 1
  end
  -- Load the image behind `url`, using an on-disk cache keyed by the MD5 of
  -- the URL.
  --
  -- Must run inside a coroutine: the download is awaited with
  -- coroutine.yield().
  --
  -- url: address of the image to fetch.
  -- Returns whatever image_loader.load_float / image_loader.decode_float
  -- return on success, or nil, nil when the fetch does not answer 200 or the
  -- response's Content-Type does not contain "image".
  local function cache_url(url)
    local hash = md5.sumhexa(url)
    local cache_file = path.join(CACHE_DIR, "url_" .. hash)
    if path.exists(cache_file) then
      return image_loader.load_float(cache_file)
    end

    -- NOTE(review): certificate verification is disabled (verify_ca=false)
    -- and the URL is supplied by the client; confirm this is intended.
    local res = coroutine.yield(
      turbo.async.HTTPClient({verify_ca=false},
                             nil,
                             CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
    )
    if not res or res.code ~= 200 then
      return nil, nil
    end

    local content_type = res.headers:get("Content-Type", true)
    -- The header lookup may hand back a list; use its first entry.
    if type(content_type) == "table" then
      content_type = content_type[1]
    end
    if not (content_type and content_type:find("image")) then
      return nil, nil
    end

    local blob = res.body
    -- Caching is best-effort: io.open returns nil when the cache directory
    -- is missing or not writable, and that must not fail the request.
    local fp = io.open(cache_file, "wb")
    if fp then
      fp:write(blob)
      fp:close()
    end
    return image_loader.decode_float(blob)
  end
  -- Extract the input image from a request.
  --
  -- An uploaded "file" argument (exactly one, non-empty) takes priority;
  -- otherwise a non-empty "url" argument is fetched through cache_url().
  --
  -- req: request handler providing get_arguments() / get_argument().
  -- Returns x, meta, filename. filename is the basename taken from the
  -- upload's content-disposition when present; all three are nil when
  -- neither a file nor a URL was supplied.
  local function get_image(req)
    local uploads = req:get_arguments("file")
    local url = req:get_argument("url", "")
    local payload, filename = nil, nil

    if uploads and #uploads == 1 then
      local entry = uploads[1]
      payload = entry[1]
      local disposition = entry["content-disposition"]
      if disposition and disposition["filename"] then
        filename = path.basename(disposition["filename"])
      end
    end

    if payload and payload:len() > 0 then
      local x, meta = image_loader.decode_float(payload)
      return x, meta, filename
    end
    if url and url:len() > 0 then
      local x, meta = cache_url(url)
      return x, meta, filename
    end
    return nil, nil, nil
  end
  -- Release a model's GPU memory via clearState(), but only when the
  -- CLEANUP_MODEL flag is enabled; otherwise do nothing.
  local function cleanup_model(model)
    if not CLEANUP_MODEL then
      return
    end
    model:clearState() -- release GPU memory
  end
  -- Run the requested waifu2x method on image `x`, with on-disk caching of
  -- both the result and the (possibly upscaled) alpha channel.
  --
  -- x:       input image.
  -- meta:    image metadata; meta.alpha is the optional alpha channel.
  -- options: table with fields
  --            prefix / alpha_prefix  cache file names (without ".png"),
  --            style                  "art" or "photo",
  --            method                 "scale", "noiseN_scale" or "noiseN",
  --            tta_level              TTA level passed to reconstruct,
  --            border                 when true, make_border() is applied first.
  -- Returns the converted image and a copy of `meta` whose alpha field holds
  -- the alpha channel matching the result.
  local function convert(x, meta, options)
    local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
    local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
    local alpha_orig = meta.alpha
    local alpha = alpha_orig

    -- Prefer a previously cached alpha channel, normalized to one plane.
    if path.exists(alpha_cache_file) then
      alpha = image_loader.load_float(alpha_cache_file)
      if alpha:dim() == 2 then
        alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
      end
      if alpha:size(1) == 3 then
        alpha = image.rgb2y(alpha)
      end
    end

    -- Cache hit: no model needs to run.
    if path.exists(cache_file) then
      local cached = image_loader.load_float(cache_file)
      local cached_meta = tablex.copy(meta)
      cached_meta.alpha = alpha
      return cached, cached_meta
    end

    local models_by_style = {art = art_model, photo = photo_model}
    local model = models_by_style[options.style]

    if options.border then
      x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(model.scale))
    end

    local upscale_methods = {
      scale = true,
      noise0_scale = true,
      noise1_scale = true,
      noise2_scale = true,
      noise3_scale = true,
    }
    local denoise_methods = {
      noise0 = true,
      noise1 = true,
      noise2 = true,
      noise3 = true,
    }

    local method = options.method
    if upscale_methods[method] then
      x = reconstruct.scale_tta(model[method], options.tta_level, 2.0, x,
                                opt.crop_size, opt.batch_size)
      -- Upscale the alpha channel too unless it already matches the output.
      if alpha and (alpha:size(2) ~= x:size(2) or alpha:size(3) ~= x:size(3)) then
        alpha = reconstruct.scale(model.scale, 2.0, alpha,
                                  opt.crop_size, opt.batch_size)
        image_loader.save_png(alpha_cache_file, alpha)
        cleanup_model(model.scale)
      end
      cleanup_model(model[method])
    elseif denoise_methods[method] then
      x = reconstruct.image_tta(model[method], options.tta_level,
                                x, opt.crop_size, opt.batch_size)
      cleanup_model(model[method])
    end

    image_loader.save_png(cache_file, x)
    local out_meta = tablex.copy(meta)
    out_meta.alpha = alpha
    return x, out_meta
  end
  -- Tell whether the client behind `handler` is gone.
  --
  -- Returns true when any link of handler.request.connection.stream is
  -- missing or when the stream reports itself closed; false otherwise.
  local function client_disconnected(handler)
    local request = handler.request
    local connection = request and request.connection
    local stream = connection and connection.stream
    if not stream then
      return true
    end
    -- Coerce to a real boolean.
    return not not stream:closed()
  end
  -- Build the download name for a converted image.
  --
  -- filename: original file name; its extension is replaced.
  -- mode:     optional tag; when given, "_waifu2x_<mode>" is appended to the
  --           base name.
  -- Returns the base name (plus tag, if any) with a ".png" extension.
  local function make_output_filename(filename, mode)
    local ext = path.extension(filename)
    local stem = filename:sub(1, filename:len() - ext:len())
    local suffix = ".png"
    if mode then
      suffix = "_waifu2x_" .. mode .. ".png"
    end
    return stem .. suffix
  end
  local APIHandler = class("APIHandler", turbo.web.RequestHandler)

  -- Make a file name safe to embed in a quoted Content-Disposition value:
  -- control characters (including CR/LF), double quotes and backslashes are
  -- replaced with "_".
  local function sanitize_header_filename(name)
    return (name:gsub('[%c"\\]', "_"))
  end

  -- POST /api: convert the submitted image and answer with a PNG.
  --
  -- Request arguments:
  --   file / url   image source (see get_image)
  --   scale        1 or 2 run an upscaling model (1 is then resized by
  --                1.6 / 2.0); default -1
  --   noise        0..3 selects a noise-reduction model; default -1
  --   tta_level    1, 2, 4, 8, or 0 to choose automatically; anything else
  --                becomes 1
  --   style        "art"; every other value is treated as "photo"
  --   download     non-empty to send the result as an attachment
  --
  -- Answers 400 when the client is gone, the image could not be obtained, or
  -- the image exceeds the size limit.
  function APIHandler:post()
    if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
    end
    local x, meta, filename = get_image(self)
    -- tonumber() yields nil for non-numeric input; fall back to the defaults
    -- so the numeric comparisons below cannot raise.
    local scale = tonumber(self:get_argument("scale", "-1")) or -1
    local noise = tonumber(self:get_argument("noise", "-1")) or -1
    local tta_level = tonumber(self:get_argument("tta_level", "1"))
    local style = self:get_argument("style", "art")
    local download = (self:get_argument("download", "")):len()
    -- Automatic TTA needs the image dimensions, so only when x was loaded.
    if tta_level == 0 and x then
      tta_level = auto_tta_level(x, scale)
    end
    if not (tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then
      tta_level = 1
    end
    if style ~= "art" then
      style = "photo" -- style must be art or photo
    end
    if x and valid_size(x, scale, tta_level) then
      local prefix = nil
      if (noise >= 0 or scale > 0) then
        local hash = md5.sumhexa(meta.blob)
        local alpha_prefix = style .. "_" .. hash .. "_alpha"
        local border = false
        if scale >= 0 and meta.alpha then
          border = true
        end
        local upscale = (scale == 1 or scale == 2)
        local denoise = (noise == 0 or noise == 1 or noise == 2 or noise == 3)
        local method = nil
        if upscale and noise < 0 then
          method = "scale"
        elseif upscale and denoise then
          method = string.format("noise%d_scale", noise)
        elseif denoise then
          method = string.format("noise%d", noise)
        end
        if method then
          prefix = style .. "_" .. method .. "_tta_" .. tta_level .. "_"
          local converted, converted_meta = convert(x, meta, {method = method,
                                                              style = style,
                                                              tta_level = tta_level,
                                                              prefix = prefix .. hash,
                                                              alpha_prefix = alpha_prefix,
                                                              border = border})
          x = converted
          if upscale then
            meta = converted_meta
            if scale == 1 then
              x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
            end
          end
          -- For noise-only requests the returned meta is deliberately
          -- dropped, so the original meta.alpha is composited below.
        end
      end
      local name = nil
      if filename then
        if prefix then
          -- Strip the trailing "_" of the prefix.
          name = make_output_filename(filename, prefix:sub(1, -2))
        else
          name = make_output_filename(filename, nil)
        end
      else
        name = uuid() .. ".png"
      end
      -- The upload's file name is client-controlled; keep it from breaking
      -- out of the quoted header value.
      name = sanitize_header_filename(name)
      local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
                                           tablex.update({depth = 8, inplace = true}, meta))
      self:set_header("Content-Length", string.format("%d", #blob))
      if download > 0 then
        self:set_header("Content-Type", "application/octet-stream")
        self:set_header("Content-Disposition", string.format('attachment; filename="%s"', name))
      else
        self:set_header("Content-Type", "image/png")
        self:set_header("Content-Disposition", string.format('inline; filename="%s"', name))
      end
      self:write(blob)
    else
      if not x then
        self:set_status(400)
        self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
      else
        self:set_status(400)
        self:write("ERROR: image size exceeds maximum allowable size.")
      end
    end
    collectgarbage()
  end
  local FormHandler = class("FormHandler", turbo.web.RequestHandler)

  -- Localized landing pages, read once at start-up.
  local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
  local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
  local index_pt = file.read(path.join(ROOT, "assets", "index.pt.html"))
  local index_es = file.read(path.join(ROOT, "assets", "index.es.html"))
  local index_fr = file.read(path.join(ROOT, "assets", "index.fr.html"))
  local index_de = file.read(path.join(ROOT, "assets", "index.de.html"))
  local index_tr = file.read(path.join(ROOT, "assets", "index.tr.html"))
  local index_zh_cn = file.read(path.join(ROOT, "assets", "index.zh-CN.html"))
  local index_ko = file.read(path.join(ROOT, "assets", "index.ko.html"))
  local index_en = file.read(path.join(ROOT, "assets", "index.html"))

  -- Accept-Language tag -> page key. Tags not listed here get the English page.
  local PAGE_KEY_BY_LANG = {
    ["ja"] = "ja",
    ["ru"] = "ru",
    ["pt"] = "pt", ["pt-BR"] = "pt",
    ["es"] = "es", ["es-ES"] = "es",
    ["fr"] = "fr",
    ["de"] = "de",
    ["tr"] = "tr",
    ["zh-CN"] = "zh_cn", ["zh"] = "zh_cn",
    ["ko"] = "ko",
  }
  local PAGE_BY_KEY = {
    ja = index_ja,
    ru = index_ru,
    pt = index_pt,
    es = index_es,
    fr = index_fr,
    de = index_de,
    tr = index_tr,
    zh_cn = index_zh_cn,
    ko = index_ko,
    en = index_en,
  }

  -- GET /: serve the landing page matching the first language tag of the
  -- Accept-Language header (parameters after ";" are ignored), or the English
  -- page when the header is absent or the tag is not recognized.
  function FormHandler:get()
    local header = self.request.headers:get("Accept-Language")
    local key = "en"
    if header then
      local first = utils.split(header, ",")[1]
      if first then
        first = utils.split(first, ";")[1]
      end
      if first and PAGE_KEY_BY_LANG[first] then
        key = PAGE_KEY_BY_LANG[first]
      end
    end
    self:write(PAGE_BY_KEY[key])
  end
  -- Log categories: success, warning and error are on; the rest are off.
  turbo.log.categories = {
    success = true,
    notice = false,
    warning = true,
    error = true,
    debug = false,
    development = false,
  }

  -- URL routing: the form page, the conversion API, and static assets
  -- (single path segment made of letters, digits, ".", "-" and "_").
  local routes = {
    {"^/$", FormHandler},
    {"^/api$", APIHandler},
    {"^/([%a%d%.%-_]+)$", turbo.web.StaticFileHandler, path.join(ROOT, "assets/")},
  }
  local app = turbo.web.Application:new(routes)

  -- Listen on all interfaces; request bodies are capped at CURL_MAX_SIZE.
  local listen_options = {max_body_size = CURL_MAX_SIZE}
  app:listen(opt.port, "0.0.0.0", listen_options)
  turbo.ioloop.instance():start()