web.lua 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. _G.TURBO_SSL = true
  2. local turbo = require 'turbo'
  3. local uuid = require 'uuid'
  4. local ffi = require 'ffi'
  5. local md5 = require 'md5'
  6. require 'pl'
  7. require 'lib.portable'
  8. require 'lib.mynn'
  -- Command-line interface. The option specs are kept in one table so the
  -- flag, its default value and its help text sit side by side.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-api")
  cmd:text("Options:")
  for _, spec in ipairs({
    {"-port", 8812, 'listen port'},
    {"-gpu", 1, 'Device ID'},
    {"-core", 2, 'number of CPU cores'},
  }) do
    cmd:option(spec[1], spec[2], spec[3])
  end
  local opt = cmd:parse(arg)

  -- Apply the parsed options to the torch runtime: select the device through
  -- cutorch, make FloatTensor the default tensor type, and set the thread
  -- count from -core.
  cutorch.setDevice(opt.gpu)
  torch.setdefaulttensortype('torch.FloatTensor')
  torch.setnumthreads(opt.core)
  20. local iproc = require './lib/iproc'
  21. local reconstruct = require './lib/reconstruct'
  22. local image_loader = require './lib/image_loader'
  -- Directory holding the serialized models.
  local MODEL_DIR = "./models/anime_style_art_rgb3"

  -- Load one model file from MODEL_DIR using torch's "ascii" format.
  local function load_model(filename)
    return torch.load(path.join(MODEL_DIR, filename), "ascii")
  end

  local noise1_model = load_model("noise1_model.t7")
  local noise2_model = load_model("noise2_model.t7")
  local scale20_model = load_model("scale2.0x_model.t7")

  -- Result caching: when enabled, processed images are stored under CACHE_DIR.
  local USE_CACHE = true
  local CACHE_DIR = "./cache"

  -- Upper bounds on x:size(2) * x:size(3) of an input image, for
  -- denoise-only requests and for requests that upscale, respectively.
  local MAX_NOISE_IMAGE = 2560 * 2560
  local MAX_SCALE_IMAGE = 1280 * 1280

  -- Options handed to the HTTP client when fetching an image by URL.
  -- NOTE(review): timeouts are presumably in seconds - confirm against the
  -- turbo HTTPClient documentation.
  local CURL_OPTIONS = {
    request_timeout = 15,
    connect_timeout = 10,
    allow_redirects = true,
    max_redirects = 2,
  }

  -- Size cap shared by the URL fetcher and the server's max_body_size
  -- (presumably bytes, i.e. 2 MiB - confirm).
  local CURL_MAX_SIZE = 2 * 1024 * 1024
  --- Check whether an image is small enough to be processed.
  -- @param x image tensor; x:size(2) * x:size(3) is compared with the limit
  -- @param scale requested scale mode; exactly 0 selects MAX_NOISE_IMAGE,
  --   any other value selects MAX_SCALE_IMAGE
  -- @return true when the product of the two sizes is within the limit
  local function valid_size(x, scale)
    local limit = MAX_SCALE_IMAGE
    if scale == 0 then
      limit = MAX_NOISE_IMAGE
    end
    return x:size(2) * x:size(3) <= limit
  end
  --- Obtain the input image for a request.
  -- The "file" argument wins when it is non-empty; otherwise the "url"
  -- argument is fetched over HTTP. Must run inside a coroutine because the
  -- fetch is awaited with coroutine.yield.
  -- @param req request handler exposing get_argument
  -- @return img decoded image, or nil when nothing could be decoded
  -- @return blob raw data that was handed to the decoder, or nil
  -- @return alpha second value returned by image_loader.decode_float, or nil
  local function get_image(req)
    local upload = req:get_argument("file", "")
    local url = req:get_argument("url", "")

    -- Uploaded data takes precedence over a URL.
    if upload and upload:len() > 0 then
      local img, alpha = image_loader.decode_float(upload)
      return img, upload, alpha
    end

    if not (url and url:len() > 0) then
      return nil, nil, nil
    end

    -- NOTE(review): verify_ca=false is passed to the client and the URL comes
    -- straight from the request - confirm both are acceptable for deployment.
    local client = turbo.async.HTTPClient({verify_ca=false},
                                          nil,
                                          CURL_MAX_SIZE)
    local res = coroutine.yield(client:fetch(url, CURL_OPTIONS))
    if res.code ~= 200 then
      return nil, nil, nil
    end

    -- The header lookup may return a table; use its first entry in that case.
    local content_type = res.headers:get("Content-Type", true)
    if type(content_type) == "table" then
      content_type = content_type[1]
    end
    if not (content_type and content_type:find("image")) then
      return nil, nil, nil
    end

    local body = res.body
    local img, alpha = image_loader.decode_float(body)
    return img, body, alpha
  end
  --- Run noise1_model over an image through reconstruct.image.
  -- @param x input image
  -- @return whatever reconstruct.image returns for (noise1_model, x)
  local function apply_denoise1(x)
    return reconstruct.image(noise1_model, x)
  end
  --- Run noise2_model over an image through reconstruct.image.
  -- @param x input image
  -- @return whatever reconstruct.image returns for (noise2_model, x)
  local function apply_denoise2(x)
    return reconstruct.image(noise2_model, x)
  end
  --- Upscale an image with scale20_model through reconstruct.scale,
  -- passing a scale factor of 2.0.
  -- @param x input image
  -- @return whatever reconstruct.scale returns for (scale20_model, 2.0, x)
  local function apply_scale2x(x)
    return reconstruct.scale(scale20_model, 2.0, x)
  end
  --- Return the cached result stored at `cache`, computing it on a miss.
  -- On a miss, func(x) is evaluated, saved to `cache` with image.save, and
  -- returned; on a hit the file is loaded with image.load and func is not
  -- called.
  -- @param cache path of the cache file
  -- @param x input passed to func on a miss
  -- @param func function producing the result from x
  -- @return the loaded or freshly computed result
  local function cache_do(cache, x, func)
    if not path.exists(cache) then
      local result = func(x)
      image.save(cache, result)
      return result
    end
    return image.load(cache)
  end
  --- Tell whether the client behind a handler has gone away.
  -- A missing request, connection or stream counts as disconnected;
  -- otherwise the answer is the truthiness of stream:closed().
  -- @param handler request handler to inspect
  -- @return true when disconnected, false otherwise
  local function client_disconnected(handler)
    local request = handler.request
    if not request then
      return true
    end
    local connection = request.connection
    if not connection then
      return true
    end
    local stream = connection.stream
    if not stream then
      return true
    end
    -- Coerce to a real boolean.
    return not not stream:closed()
  end
  local APIHandler = class("APIHandler", turbo.web.RequestHandler)

  --- Denoise and/or upscale `x`, reusing intermediate PNGs under CACHE_DIR.
  -- Cache files are keyed by the MD5 hex digest of the raw source data.
  -- @param x decoded input image
  -- @param src raw source data used as the cache key
  -- @param noise noise mode (1 or 2 run a denoise model; other values skip it)
  -- @param scale scale mode (1 or 2 run the 2x model; 1 then shrinks the result)
  -- @return the processed image
  local function process_cached(x, src, noise, scale)
    local hash = md5.sumhexa(src)
    local function cache_file(suffix)
      return path.join(CACHE_DIR, hash .. suffix)
    end

    if noise == 1 then
      x = cache_do(cache_file("_noise1.png"), x, apply_denoise1)
    elseif noise == 2 then
      x = cache_do(cache_file("_noise2.png"), x, apply_denoise2)
    end

    if scale == 1 or scale == 2 then
      -- The upscaled cache entry depends on which denoise step preceded it.
      local scaled_cache
      if noise == 1 then
        scaled_cache = cache_file("_noise1_scale.png")
      elseif noise == 2 then
        scaled_cache = cache_file("_noise2_scale.png")
      else
        scaled_cache = cache_file("_scale.png")
      end
      x = cache_do(scaled_cache, x, apply_scale2x)

      if scale == 1 then
        -- x is already the 2x result here, so 1.6x of the input is
        -- (1.6 / 2.0) of the current size.
        x = iproc.scale(x,
                        math.floor(x:size(3) * (1.6 / 2.0) + 0.5),
                        math.floor(x:size(2) * (1.6 / 2.0) + 0.5),
                        "Jinc")
      end
    end
    return x
  end

  --- Denoise and/or upscale `x` without touching the cache.
  -- @param x decoded input image
  -- @param noise noise mode (1 or 2 run a denoise model; other values skip it)
  -- @param scale scale mode (1: 2x model then resize to 1.6x; 2: 2x model)
  -- @return the processed image
  local function process_uncached(x, noise, scale)
    if noise == 1 then
      x = apply_denoise1(x)
    elseif noise == 2 then
      x = apply_denoise2(x)
    end

    if scale == 1 then
      -- Target sizes are taken from the image before it is upscaled.
      local target_dim3 = math.floor(x:size(3) * 1.6 + 0.5)
      local target_dim2 = math.floor(x:size(2) * 1.6 + 0.5)
      x = apply_scale2x(x)
      x = iproc.scale(x, target_dim3, target_dim2, "Jinc")
    elseif scale == 2 then
      x = apply_scale2x(x)
    end
    return x
  end

  --- Handle POST /api: read an image, process it, and reply with a PNG.
  -- Request arguments: "file" or "url" (read by get_image), "scale" and
  -- "noise" (numbers, default 0; non-numeric values are treated as 0).
  -- Replies with status 400 and a plain-text message when the client has
  -- already disconnected, when no image could be decoded, or when the image
  -- exceeds the size limit for the requested mode.
  function APIHandler:post()
    if client_disconnected(self) then
      self:set_status(400)
      self:write("client disconnected")
      return
    end

    local x, src, alpha = get_image(self)
    -- tonumber returns nil for non-numeric input; fall back to "no-op" (0)
    -- rather than letting nil select the wrong size limit and branch.
    local scale = tonumber(self:get_argument("scale", "0")) or 0
    local noise = tonumber(self:get_argument("noise", "0")) or 0

    if not x then
      self:set_status(400)
      self:write("ERROR: unsupported image format.")
    elseif not valid_size(x, scale) then
      self:set_status(400)
      self:write("ERROR: image size exceeds maximum allowable size.")
    else
      if noise ~= 0 or scale ~= 0 then
        if USE_CACHE then
          x = process_cached(x, src, noise, scale)
        else
          x = process_uncached(x, noise, scale)
        end
      end

      local name = uuid() .. ".png"
      -- NOTE(review): alpha is passed through unchanged even when x was
      -- rescaled - confirm encode_png copes with the size mismatch.
      local blob, len = image_loader.encode_png(x, alpha)
      -- Content-Disposition requires a disposition type before its
      -- parameters; a bare filename="..." is not valid header syntax.
      self:set_header("Content-Disposition",
                      string.format('inline; filename="%s"', name))
      self:set_header("Content-Type", "image/png")
      self:set_header("Content-Length", string.format("%d", len))
      self:write(ffi.string(blob, len))
    end
    -- Run a full collection after each request has been handled.
    collectgarbage()
  end
  local FormHandler = class("FormHandler", turbo.web.RequestHandler)

  -- Landing pages are read once at startup and served from memory.
  local index_ja = file.read("./assets/index.ja.html")
  local index_ru = file.read("./assets/index.ru.html")
  local index_en = file.read("./assets/index.html")

  --- Extract the primary language subtag of the first Accept-Language entry.
  -- "ja-JP,ja;q=0.8" and " JA;q=1" both yield "ja". Region subtags, quality
  -- values, letter case and leading whitespace are ignored.
  -- @param header raw header value (string, table of strings, or nil)
  -- @return lower-case primary subtag, or nil when none can be extracted
  local function primary_language(header)
    if type(header) == "table" then
      header = header[1]
    end
    if type(header) ~= "string" then
      return nil
    end
    local tag = header:match("^%s*(%a+)")
    if not tag then
      return nil
    end
    return tag:lower()
  end

  --- Handle GET /: serve the landing page matching the client's first
  -- preferred language (Japanese, Russian, otherwise English).
  function FormHandler:get()
    local lang = primary_language(self.request.headers:get("Accept-Language"))
    if lang == "ja" then
      self:write(index_ja)
    elseif lang == "ru" then
      self:write(index_ru)
    else
      self:write(index_en)
    end
  end
  -- Logging categories: success, warning and error are switched on;
  -- notice, debug and development are switched off.
  turbo.log.categories = {
    success = true,
    notice = false,
    warning = true,
    error = true,
    debug = false,
    development = false,
  }
  -- Resolve a file name inside the assets directory.
  local function asset(name)
    return path.join("./assets", name)
  end

  -- URL routing table: the language-negotiating form at "/", the three
  -- static landing pages, and the image-processing API.
  -- NOTE(review): the "." in the index patterns is unescaped and the
  -- patterns have no trailing "$" - presumably they match more URLs than
  -- the literal file names; confirm against turbo's route matching.
  local routes = {
    {"^/$", FormHandler},
    {"^/index.html", turbo.web.StaticFileHandler, asset("index.html")},
    {"^/index.ja.html", turbo.web.StaticFileHandler, asset("index.ja.html")},
    {"^/index.ru.html", turbo.web.StaticFileHandler, asset("index.ru.html")},
    {"^/api$", APIHandler},
  }
  local app = turbo.web.Application:new(routes)

  -- Listen on all interfaces with the request body capped at CURL_MAX_SIZE,
  -- then hand control to the I/O loop.
  local listen_options = {max_body_size = CURL_MAX_SIZE}
  app:listen(opt.port, "0.0.0.0", listen_options)
  turbo.ioloop.instance():start()