web.lua 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. require 'pl'
  2. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  3. local ROOT = path.dirname(__FILE__)
  4. package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
  5. _G.TURBO_SSL = true
  6. require 'w2nn'
  7. local uuid = require 'uuid'
  8. local ffi = require 'ffi'
  9. local md5 = require 'md5'
  10. local iproc = require 'iproc'
  11. local reconstruct = require 'reconstruct'
  12. local image_loader = require 'image_loader'
  13. local alpha_util = require 'alpha_util'
  14. local compression = require 'compression'
  15. local gm = require 'graphicsmagick'
  -- Note: turbo and xlua have different implementations of string:split(),
  -- so the two definitions conflict. This script relies on turbo's
  -- string:split().
  19. local turbo = require 'turbo'
  -- Command-line interface. Options are registered in exactly the order listed
  -- below so the generated help text is unchanged.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-api")
  cmd:text("Options:")
  local CLI_OPTIONS = {
    -- {flag, default, help}
    {"-port", 8812, 'listen port'},
    {"-gpu", 1, 'Device ID'},
    {"-enable_tta", 0, 'enable TTA query(0|1)'},
    {"-crop_size", 256, 'patch size per process'},
    {"-batch_size", 1, 'batch size'},
    {"-thread", -1, 'number of CPU threads'},
    {"-force_cudnn", 0, 'use cuDNN backend (0|1)'},
    {"-max_pixels", 3000 * 3000, 'maximum number of output image pixels (e.g. 3000x3000=9000000)'},
    {"-curl_request_timeout", 60, "request_timeout for curl"},
    {"-curl_connect_timeout", 60, "connect_timeout for curl"},
    {"-curl_max_redirects", 2, "max_redirects for curl"},
    {"-max_body_size", 5 * 1024 * 1024, "maximum allowed size for uploaded files"},
    {"-cache_max", 200, "number of cached images on RAM"},
  }
  for _, spec in ipairs(CLI_OPTIONS) do
    cmd:option(spec[1], spec[2], spec[3])
  end
  local opt = cmd:parse(arg)

  -- Runtime setup: GPU device, default tensor type, CPU thread count, cuDNN tuning.
  cutorch.setDevice(opt.gpu)
  torch.setdefaulttensortype('torch.FloatTensor')
  if opt.thread > 0 then
    torch.setnumthreads(opt.thread)
  end
  if cudnn then
    cudnn.fastest = true
    cudnn.benchmark = true
  end
  -- The 0/1 integer flags are used as booleans from here on.
  opt.force_cudnn = (opt.force_cudnn == 1)
  opt.enable_tta = (opt.enable_tta == 1)
  --local ART_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "art")
  local ART_MODEL_DIR = path.join(ROOT, "models", "cunet", "art")
  local PHOTO_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "photo")

  -- Model slots in load order: {key in the model table, file name in the model dir}.
  local MODEL_FILES = {
    {"scale", "scale2.0x_model.t7"},
    {"noise0_scale", "noise0_scale2.0x_model.t7"},
    {"noise1_scale", "noise1_scale2.0x_model.t7"},
    {"noise2_scale", "noise2_scale2.0x_model.t7"},
    {"noise3_scale", "noise3_scale2.0x_model.t7"},
    {"noise0", "noise0_model.t7"},
    {"noise1", "noise1_model.t7"},
    {"noise2", "noise2_model.t7"},
    {"noise3", "noise3_model.t7"},
  }

  -- Load every model listed in MODEL_FILES from `dir` into a table keyed by slot name.
  local function load_model_set(dir)
    local models = {}
    for _, entry in ipairs(MODEL_FILES) do
      local slot, file_name = entry[1], entry[2]
      models[slot] = w2nn.load_model(path.join(dir, file_name), opt.force_cudnn)
    end
    return models
  end

  local art_model = load_model_set(ART_MODEL_DIR)
  local photo_model = load_model_set(PHOTO_MODEL_DIR)
  collectgarbage()
  -- When true, cleanup_model() calls clearState() after each use to release GPU
  -- memory; useful on low-memory GPUs.
  local CLEANUP_MODEL = true
  local CACHE_DIR = path.join(ROOT, "cache")
  -- Pixel budgets for the input image. Noise-only keeps the size; the 2x models
  -- quadruple the pixel count, so the scale budget is a quarter of -max_pixels.
  local MAX_NOISE_IMAGE = opt.max_pixels
  local max_output_side = math.sqrt(opt.max_pixels)
  local MAX_SCALE_IMAGE = (max_output_side / 2)^2
  local PNG_DEPTH = 8
  -- Limits for fetching images by URL (turbo HTTPClient).
  local CURL_MAX_SIZE = opt.max_body_size
  local CURL_OPTIONS = {
    request_timeout = opt.curl_request_timeout,
    connect_timeout = opt.curl_connect_timeout,
    allow_redirects = true,
    max_redirects = opt.curl_max_redirects,
  }
  -- Return true when image tensor `x` (pixel count = size(2) * size(3)) fits the
  -- pixel budget for the operation, after dividing the budget by `tta_level`.
  -- `scale <= 0` selects the noise-only budget, anything else the upscaling budget.
  local function valid_size(x, scale, tta_level)
    local budget = MAX_SCALE_IMAGE
    if scale <= 0 then
      budget = MAX_NOISE_IMAGE
    end
    -- round the budget down to a perfect square, as the limit is thought of as side^2
    local side = math.floor((budget / tta_level) ^ 0.5)
    return x:size(2) * x:size(3) <= side ^ 2
  end
  -- Choose the highest TTA level (8, then 4, then 2) whose reduced pixel budget
  -- still admits image tensor `x`; return 1 (no TTA) if none does.
  -- `scale <= 0` selects the noise-only budget, anything else the upscaling budget.
  local function auto_tta_level(x, scale)
    local budget
    if scale <= 0 then
      budget = MAX_NOISE_IMAGE
    else
      budget = MAX_SCALE_IMAGE
    end
    local pixels = x:size(2) * x:size(3)
    for _, level in ipairs({8, 4, 2}) do
      local side = math.floor((budget / level) ^ 0.5)
      if pixels <= side ^ 2 then
        return level
      end
    end
    return 1
  end
  -- Fetch the image at `url`, caching the raw response body on disk as
  -- CACHE_DIR/url_<md5(url)>. Must run inside a turbo coroutine: the HTTP fetch
  -- is awaited with coroutine.yield.
  -- Returns the result of image_loader.load_float (cache hit) or
  -- image_loader.decode_float (fresh fetch), or nil, nil when the fetch does not
  -- return 200, the response is not an image, or the body cannot be decoded.
  -- NOTE(review): fetches arbitrary client-supplied URLs (possible SSRF against
  -- internal hosts) with CA verification disabled -- confirm this is acceptable.
  local function cache_url(url)
    local hash = md5.sumhexa(url)
    local cache_file = path.join(CACHE_DIR, "url_" .. hash)
    if path.exists(cache_file) then
      return image_loader.load_float(cache_file)
    end
    local res = coroutine.yield(
      turbo.async.HTTPClient({verify_ca=false},
                             nil,
                             CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
    )
    if not (res and res.code == 200) then
      return nil, nil
    end
    local content_type = res.headers:get("Content-Type", true)
    if type(content_type) == "table" then
      content_type = content_type[1]
    end
    -- plain (non-pattern) substring search
    if not (type(content_type) == "string" and content_type:find("image", 1, true)) then
      return nil, nil
    end
    local blob = res.body
    local x, meta = image_loader.decode_float(blob)
    if x then
      -- Persist only bodies that decode; otherwise a bad response would be
      -- served from the disk cache on every later request for this URL.
      local fp = io.open(cache_file, "wb")
      if fp then -- caching is best-effort (e.g. missing or read-only cache dir)
        fp:write(blob)
        fp:close()
      end
    end
    return x, meta
  end
  -- Obtain the input image for an API request. A single uploaded multipart
  -- "file" field wins over the "url" argument.
  -- Returns x, meta, filename where filename is the basename of the uploaded
  -- file's name (nil for URLs or when absent); nil, nil, nil if no source given.
  local function get_image(req)
    local uploads = req:get_arguments("file")
    local url = req:get_argument("url", "")
    local body, filename = nil, nil
    if uploads and #uploads == 1 then
      local upload = uploads[1]
      body = upload[1]
      local disposition = upload["content-disposition"]
      if disposition and disposition["filename"] then
        filename = path.basename(disposition["filename"])
      end
    end
    if body and body:len() > 0 then
      local x, meta = image_loader.decode_float(body)
      return x, meta, filename
    end
    if url and url:len() > 0 then
      local x, meta = cache_url(url)
      return x, meta, filename
    end
    return nil, nil, nil
  end
  -- Clear a model's intermediate state after use when CLEANUP_MODEL is enabled.
  local function cleanup_model(model)
    if not CLEANUP_MODEL then
      return
    end
    model:clearState() -- release GPU memory
  end
  -- In-memory cache of converted images, keyed by cache file path.
  -- Each entry holds the compressed image, compressed alpha and updated_at.
  local g_cache = {}
  -- Number of entries in g_cache (a hash table, so `#` cannot be used).
  local function cache_count()
    local n = 0
    local key = next(g_cache)
    while key ~= nil do
      n = n + 1
      key = next(g_cache, key)
    end
    return n
  end
  -- Evict the g_cache entry with the smallest updated_at. On ties, whichever
  -- entry pairs() visits first is evicted. No-op on an empty cache.
  local function cache_remove_old()
    local oldest_key, oldest_time = nil, nil
    for key, entry in pairs(g_cache) do
      if oldest_time == nil or entry.updated_at < oldest_time then
        oldest_key, oldest_time = key, entry.updated_at
      end
    end
    if oldest_key then
      g_cache[oldest_key] = nil
    end
  end
  -- Compress a float image for storage in g_cache. The image is converted to
  -- bytes via iproc.float2byte first. A nil input (e.g. an image without alpha)
  -- yields nil.
  local function cache_compress(raw_image)
    if not raw_image then
      return nil
    end
    -- declared local so the result does not leak into _G
    local compressed_image = compression.compress(iproc.float2byte(raw_image))
    return compressed_image
  end
  -- Inverse of cache_compress: decompress and convert the bytes back to a float
  -- image. A nil input yields nil.
  local function cache_decompress(compressed_image)
    if not compressed_image then
      return nil
    end
    return iproc.byte2float(compression.decompress(compressed_image))
  end
  -- Look up `filename` in g_cache. On a hit, return a new table
  -- {image = ..., alpha = ...} with both parts decompressed (alpha may be nil);
  -- on a miss, return nil. A hit does not touch updated_at.
  local function cache_get(filename)
    local entry = g_cache[filename]
    if not entry then
      return nil
    end
    local image = cache_decompress(entry.image)
    local alpha = cache_decompress(entry.alpha)
    return {image = image, alpha = alpha}
  end
  -- Store a converted image (and optional alpha) in g_cache under `filename`,
  -- keeping at most opt.cache_max entries.
  local function cache_put(filename, image, alpha)
    if opt.cache_max <= 0 then
      return -- caching disabled: nothing may be kept
    end
    local entry = {image = cache_compress(image),
                   alpha = cache_compress(alpha),
                   updated_at = os.time()}
    -- Make room *before* inserting. updated_at has one-second resolution, so
    -- evicting after the insert could remove the entry just added on a tie.
    -- Replacing an existing key does not grow the cache, so no eviction is needed.
    if g_cache[filename] == nil and cache_count() >= opt.cache_max then
      cache_remove_old()
    end
    g_cache[filename] = entry
  end
  -- Methods handled by the 2x upscaling path and by the same-size denoising path.
  local CONVERT_SCALE_METHODS = {
    scale = true,
    noise0_scale = true,
    noise1_scale = true,
    noise2_scale = true,
    noise3_scale = true,
  }
  local CONVERT_NOISE_METHODS = {noise0 = true, noise1 = true, noise2 = true, noise3 = true}

  -- Run the model chosen by options.style ("art" | "photo") and options.method on
  -- image x, memoizing the result (image + alpha) in g_cache under
  -- CACHE_DIR/<options.prefix>.png.
  -- options: style, method, tta_level, prefix, border (bool); alpha_prefix is
  -- accepted but not used here.
  -- Returns the converted image and a shallow copy of meta whose .alpha matches
  -- it (upscaled to the new size on the scale path).
  local function convert(x, meta, options)
    local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
    local alpha = meta.alpha
    local cached = cache_get(cache_file)
    if cached then
      local out_meta = tablex.copy(meta)
      out_meta.alpha = cached.alpha
      return cached.image, out_meta
    end

    local model
    if options.style == "art" then
      model = art_model
    elseif options.style == "photo" then
      model = photo_model
    end
    local method = options.method

    if options.border then
      -- alpha is still the caller's original alpha here; the padding size comes
      -- from the scale model's offset
      x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model.scale))
    end

    if CONVERT_SCALE_METHODS[method] then
      local net = model[method]
      x = reconstruct.scale_tta(net, options.tta_level, 2.0, x,
                                opt.crop_size, opt.batch_size)
      -- bring alpha up to the new size unless it already matches
      if alpha and not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
        alpha = reconstruct.scale(model.scale, 2.0, alpha,
                                  opt.crop_size, opt.batch_size)
        cleanup_model(model.scale)
      end
      cleanup_model(net)
    elseif CONVERT_NOISE_METHODS[method] then
      local net = model[method]
      x = reconstruct.image_tta(net, options.tta_level,
                                x, opt.crop_size, opt.batch_size)
      cleanup_model(net)
    end

    cache_put(cache_file, x, alpha)
    local out_meta = tablex.copy(meta)
    out_meta.alpha = alpha
    return x, out_meta
  end
  -- True when there is no client left to answer: the handler has no request,
  -- connection or stream, or the stream reports it is closed.
  local function client_disconnected(handler)
    local request = handler.request
    if not request then
      return true
    end
    local connection = request.connection
    if not connection then
      return true
    end
    local stream = connection.stream
    if not stream then
      return true
    end
    return not not stream:closed()
  end
  -- Build the download name for a converted image. The extension of `filename`
  -- is removed, "_waifu2x_<mode>" is appended when `mode` is given, and ".png"
  -- is added.
  -- The result goes inside a quoted Content-Disposition parameter, and
  -- `filename` comes from the client's upload. Control characters, double
  -- quotes and backslashes are therefore replaced with "_" so they cannot break
  -- or inject headers.
  local function make_output_filename(filename, mode)
    local ext = path.extension(filename)
    local base = filename:sub(1, #filename - #ext)
    base = base:gsub('[%c"\\]', "_")
    if mode then
      return base .. "_waifu2x_" .. mode .. ".png"
    end
    return base .. ".png"
  end
  local APIHandler = class("APIHandler", turbo.web.RequestHandler)

  -- Read request argument `name` as a number. Missing, non-numeric or NaN values
  -- yield `default`, so later comparisons (scale <= 0, noise >= 0) never see nil.
  local function number_argument(handler, name, default)
    local value = tonumber(handler:get_argument(name, tostring(default)))
    if value == nil or value ~= value then
      return default
    end
    return value
  end

  -- Normalize the requested TTA level.
  -- Without -enable_tta it is always 1. With it, 0 asks for an automatic choice,
  -- which is only possible once an image has been decoded. Any value other than
  -- 1, 2, 4 or 8 falls back to 1.
  local function resolve_tta_level(requested, x, scale)
    if not opt.enable_tta then
      return 1
    end
    if requested == 0 then
      if x then
        return auto_tta_level(x, scale)
      end
      return 1 -- no image: the request is rejected with 400 below
    end
    if requested == 1 or requested == 2 or requested == 4 or requested == 8 then
      return requested
    end
    return 1
  end

  --- Handle POST /api: convert the uploaded or URL-referenced image and reply with a PNG.
  -- Request arguments:
  --   file / url  image source (a single uploaded file wins over url)
  --   scale       1 = 1.6x, 2 = 2x (default -1: no upscaling)
  --   noise       0-3 = denoise level (default -1: no denoising)
  --   tta_level   1, 2, 4, 8, or 0 for automatic (honored only with -enable_tta)
  --   style       "art" (default); any other value selects "photo"
  --   download    non-empty -> sent as an attachment instead of inline
  -- Replies 400 with a plain-text message when the client is gone, no image
  -- could be obtained, or the image exceeds the size limit.
  function APIHandler:post()
    if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
    end
    local x, meta, filename = get_image(self)
    local scale = number_argument(self, "scale", -1)
    local noise = number_argument(self, "noise", -1)
    local tta_level = number_argument(self, "tta_level", 1)
    local style = self:get_argument("style", "art")
    local download = (self:get_argument("download", "")):len()
    -- get_image may have yielded on a URL fetch, so check the connection again.
    if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
    end
    tta_level = resolve_tta_level(tta_level, x, scale)
    if style ~= "art" then
      style = "photo" -- style must be art or photo
    end
    if x and valid_size(x, scale, tta_level) then
      local prefix = nil
      if (noise >= 0 or scale > 0) then
        local hash = md5.sumhexa(meta.blob)
        local alpha_prefix = style .. "_" .. hash .. "_alpha"
        local border = false
        if scale >= 0 and meta.alpha then
          border = true
        end
        if (scale == 1 or scale == 2) and (noise < 0) then
          prefix = style .. "_scale_tta_" .. tta_level .. "_"
          x, meta = convert(x, meta, {method = "scale",
                                      style = style,
                                      tta_level = tta_level,
                                      prefix = prefix .. hash,
                                      alpha_prefix = alpha_prefix,
                                      border = border})
          if scale == 1 then
            -- convert() upscales by 2x; resample down to 1.6x
            x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
          end
        elseif (scale == 1 or scale == 2) and (noise == 0 or noise == 1 or noise == 2 or noise == 3) then
          prefix = style .. string.format("_noise%d_scale_tta_", noise) .. tta_level .. "_"
          x, meta = convert(x, meta, {method = string.format("noise%d_scale", noise),
                                      style = style,
                                      tta_level = tta_level,
                                      prefix = prefix .. hash,
                                      alpha_prefix = alpha_prefix,
                                      border = border})
          if scale == 1 then
            -- convert() upscales by 2x; resample down to 1.6x
            x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
          end
        elseif (noise == 0 or noise == 1 or noise == 2 or noise == 3) then
          prefix = style .. string.format("_noise%d_tta_", noise) .. tta_level .. "_"
          -- denoising does not touch alpha, so the caller's meta is kept as-is
          x = convert(x, meta, {method = string.format("noise%d", noise),
                                style = style,
                                tta_level = tta_level,
                                prefix = prefix .. hash,
                                alpha_prefix = alpha_prefix,
                                border = border})
        end
      end
      local name = nil
      if filename then
        if prefix then
          name = make_output_filename(filename, prefix:sub(0, prefix:len()-1))
        else
          name = make_output_filename(filename, nil)
        end
      else
        name = uuid() .. ".png"
      end
      local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
                                           tablex.update({depth = PNG_DEPTH, inplace = true}, meta))
      self:set_header("Content-Length", string.format("%d", #blob))
      if download > 0 then
        self:set_header("Content-Type", "application/octet-stream")
        self:set_header("Content-Disposition", string.format('attachment; filename="%s"', name))
      else
        self:set_header("Content-Type", "image/png")
        self:set_header("Content-Disposition", string.format('inline; filename="%s"', name))
      end
      self:write(blob)
    else
      if not x then
        self:set_status(400)
        self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
      else
        self:set_status(400)
        self:write("ERROR: image size exceeds maximum allowable size.")
      end
    end
    collectgarbage()
  end
  local FormHandler = class("FormHandler", turbo.web.RequestHandler)
  local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
  local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
  local index_pt = file.read(path.join(ROOT, "assets", "index.pt.html"))
  local index_es = file.read(path.join(ROOT, "assets", "index.es.html"))
  local index_fr = file.read(path.join(ROOT, "assets", "index.fr.html"))
  local index_de = file.read(path.join(ROOT, "assets", "index.de.html"))
  local index_tr = file.read(path.join(ROOT, "assets", "index.tr.html"))
  local index_zh_cn = file.read(path.join(ROOT, "assets", "index.zh-CN.html"))
  local index_zh_tw = file.read(path.join(ROOT, "assets", "index.zh-TW.html"))
  local index_ko = file.read(path.join(ROOT, "assets", "index.ko.html"))
  local index_nl = file.read(path.join(ROOT, "assets", "index.nl.html"))
  local index_ca = file.read(path.join(ROOT, "assets", "index.ca.html"))
  local index_ro = file.read(path.join(ROOT, "assets", "index.ro.html"))
  local index_it = file.read(path.join(ROOT, "assets", "index.it.html"))
  local index_eo = file.read(path.join(ROOT, "assets", "index.eo.html"))
  local index_no = file.read(path.join(ROOT, "assets", "index.no.html"))
  local index_uk = file.read(path.join(ROOT, "assets", "index.uk.html"))
  local index_pl = file.read(path.join(ROOT, "assets", "index.pl.html"))
  local index_bg = file.read(path.join(ROOT, "assets", "index.bg.html"))
  local index_en = file.read(path.join(ROOT, "assets", "index.html"))

  -- Localized index pages keyed by lower-case language tag. Region variants such
  -- as "fr-FR", "pt-BR" or "ca-ES" resolve through their primary tag in
  -- index_for_language; only the Chinese variants need explicit entries.
  local INDEX_BY_LANGUAGE = {
    ["ja"] = index_ja,
    ["ru"] = index_ru,
    ["pt"] = index_pt,
    ["es"] = index_es,
    ["fr"] = index_fr,
    ["de"] = index_de,
    ["tr"] = index_tr,
    ["zh"] = index_zh_cn,
    ["zh-cn"] = index_zh_cn,
    ["zh-tw"] = index_zh_tw,
    ["zh-hant"] = index_zh_tw,
    ["zh-hk"] = index_zh_tw,
    ["ko"] = index_ko,
    ["nl"] = index_nl,
    ["ca"] = index_ca,
    ["ro"] = index_ro,
    ["it"] = index_it,
    ["eo"] = index_eo,
    ["no"] = index_no,
    ["uk"] = index_uk,
    ["pl"] = index_pl,
    ["bg"] = index_bg,
  }

  -- Resolve one language tag to a page. The whole tag is tried first
  -- (case-insensitively), then trailing "-subtag" parts are dropped one at a time
  -- ("zh-Hant-TW" -> "zh-hant" -> "zh"), as in RFC 4647 lookup.
  -- Returns nil when nothing matches.
  local function index_for_language(tag)
    local candidate = tag:lower()
    while candidate ~= "" do
      local page = INDEX_BY_LANGUAGE[candidate]
      if page then
        return page
      end
      local shorter = candidate:match("^(.*)%-[^%-]*$")
      if shorter == nil then
        return nil
      end
      candidate = shorter
    end
    return nil
  end

  --- Serve the index page in the language of the first Accept-Language entry,
  -- falling back to English.
  function FormHandler:get()
    local header = self.request.headers:get("Accept-Language")
    if type(header) == "table" then
      header = header[1]
    end
    local page = nil
    if type(header) == "string" then
      -- e.g. "fr-FR,fr;q=0.9,en;q=0.8": only the first entry is considered
      local first = utils.split(header, ",")[1]
      if first then
        first = utils.split(first, ";")[1]
      end
      if first then
        page = index_for_language(first:match("^%s*(.-)%s*$"))
      end
    end
    self:write(page or index_en)
  end
  -- Log only success, warning and error messages.
  turbo.log.categories = {
    ["success"] = true,
    ["notice"] = false,
    ["warning"] = true,
    ["error"] = true,
    ["debug"] = false,
    ["development"] = false
  }
  -- Routes: form page, conversion API, then static files from assets/ (one path
  -- segment of letters, digits, '.', '-' and '_' only).
  local routes = {
    {"^/$", FormHandler},
    {"^/api$", APIHandler},
    {"^/([%a%d%.%-_]+)$", turbo.web.StaticFileHandler, path.join(ROOT, "assets/")},
  }
  local app = turbo.web.Application:new(routes)
  -- Uploaded request bodies share the size limit used for URL fetches (-max_body_size).
  local listen_options = {max_body_size = CURL_MAX_SIZE}
  app:listen(opt.port, "0.0.0.0", listen_options)
  turbo.ioloop.instance():start()