|
@@ -1,10 +1,10 @@
|
|
|
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
|
|
-package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
|
|
+require 'pl'
|
|
|
+local ROOT = path.dirname(__FILE__)
|
|
|
+package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
|
|
|
_G.TURBO_SSL = true
|
|
|
|
|
|
-require 'pl'
|
|
|
require 'w2nn'
|
|
|
-local turbo = require 'turbo'
|
|
|
local uuid = require 'uuid'
|
|
|
local ffi = require 'ffi'
|
|
|
local md5 = require 'md5'
|
|
@@ -12,6 +12,11 @@ local iproc = require 'iproc'
|
|
|
local reconstruct = require 'reconstruct'
|
|
|
local image_loader = require 'image_loader'
|
|
|
|
|
|
+-- Notes: turbo and xlua has different implementation of string:split().
|
|
|
+-- Therefore, string:split() has conflict issue.
|
|
|
+-- In this script, use turbo's string:split().
|
|
|
+local turbo = require 'turbo'
|
|
|
+
|
|
|
local cmd = torch.CmdLine()
|
|
|
cmd:text()
|
|
|
cmd:text("waifu2x-api")
|
|
@@ -29,14 +34,14 @@ if cudnn then
|
|
|
cudnn.fastest = true
|
|
|
cudnn.benchmark = false
|
|
|
end
|
|
|
+local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
|
|
|
+local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
|
|
|
+local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
|
|
|
+local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
|
|
|
+local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
|
|
+local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
|
|
|
|
|
-local MODEL_DIR = "./models/anime_style_art_rgb"
|
|
|
-local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
|
|
|
-local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
|
|
|
-local scale20_model = torch.load(path.join(MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
|
|
-
|
|
|
-local USE_CACHE = true
|
|
|
-local CACHE_DIR = "./cache"
|
|
|
+local CACHE_DIR = path.join(ROOT, "cache")
|
|
|
local MAX_NOISE_IMAGE = 2560 * 2560
|
|
|
local MAX_SCALE_IMAGE = 1280 * 1280
|
|
|
local CURL_OPTIONS = {
|
|
@@ -55,15 +60,6 @@ local function valid_size(x, scale)
|
|
|
end
|
|
|
end
|
|
|
|
|
|
-local function apply_denoise1(x)
|
|
|
- return reconstruct.image(noise1_model, x)
|
|
|
-end
|
|
|
-local function apply_denoise2(x)
|
|
|
- return reconstruct.image(noise2_model, x)
|
|
|
-end
|
|
|
-local function apply_scale2x(x)
|
|
|
- return reconstruct.scale(scale20_model, 2.0, x)
|
|
|
-end
|
|
|
local function cache_url(url)
|
|
|
local hash = md5.sumhexa(url)
|
|
|
local cache_file = path.join(CACHE_DIR, "url_" .. hash)
|
|
@@ -91,15 +87,6 @@ local function cache_url(url)
|
|
|
end
|
|
|
return nil, nil, nil
|
|
|
end
|
|
|
-local function cache_do(cache, x, func)
|
|
|
- if path.exists(cache) then
|
|
|
- return image.load(cache)
|
|
|
- else
|
|
|
- x = func(x)
|
|
|
- image.save(cache, x)
|
|
|
- return x
|
|
|
- end
|
|
|
-end
|
|
|
local function get_image(req)
|
|
|
local file = req:get_argument("file", "")
|
|
|
local url = req:get_argument("url", "")
|
|
@@ -114,7 +101,30 @@ local function get_image(req)
|
|
|
end
|
|
|
return nil, nil, nil
|
|
|
end
|
|
|
-
|
|
|
+local function convert(x, options)
|
|
|
+ local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
|
|
|
+ if path.exists(cache_file) then
|
|
|
+ return image.load(cache_file)
|
|
|
+ else
|
|
|
+ if options.style == "art" then
|
|
|
+ if options.method == "scale" then
|
|
|
+ x = reconstruct.scale(art_scale2_model, 2.0, x)
|
|
|
+ w2nn.cleanup_model(art_scale2_model)
|
|
|
+ elseif options.method == "noise1" then
|
|
|
+ x = reconstruct.image(art_noise1_model, x)
|
|
|
+ w2nn.cleanup_model(art_noise1_model)
|
|
|
+ else -- options.method == "noise2"
|
|
|
+ x = reconstruct.image(art_noise2_model, x)
|
|
|
+ w2nn.cleanup_model(art_noise2_model)
|
|
|
+ end
|
|
|
+ else -- photo
|
|
|
+ x = reconstruct.scale(photo_scale2_model, 2.0, x)
|
|
|
+ w2nn.cleanup_model(photo_scale2_model)
|
|
|
+ end
|
|
|
+ image.save(cache_file, x)
|
|
|
+ return x
|
|
|
+ end
|
|
|
+end
|
|
|
local function client_disconnected(handler)
|
|
|
return not(handler.request and
|
|
|
handler.request.connection and
|
|
@@ -129,30 +139,28 @@ function APIHandler:post()
|
|
|
self:write("client disconnected")
|
|
|
return
|
|
|
end
|
|
|
- local x, alpha, src = get_image(self)
|
|
|
+ local x, alpha, blob = get_image(self)
|
|
|
local scale = tonumber(self:get_argument("scale", "0"))
|
|
|
local noise = tonumber(self:get_argument("noise", "0"))
|
|
|
+ local style = self:get_argument("style", "art")
|
|
|
+ if style ~= "art" then
|
|
|
+ style = "photo" -- style must be art or photo
|
|
|
+ end
|
|
|
if x and valid_size(x, scale) then
|
|
|
- if USE_CACHE and (noise ~= 0 or scale ~= 0) then
|
|
|
- local hash = md5.sumhexa(src)
|
|
|
- local cache_noise1 = path.join(CACHE_DIR, hash .. "_noise1.png")
|
|
|
- local cache_noise2 = path.join(CACHE_DIR, hash .. "_noise2.png")
|
|
|
- local cache_scale = path.join(CACHE_DIR, hash .. "_scale.png")
|
|
|
- local cache_noise1_scale = path.join(CACHE_DIR, hash .. "_noise1_scale.png")
|
|
|
- local cache_noise2_scale = path.join(CACHE_DIR, hash .. "_noise2_scale.png")
|
|
|
-
|
|
|
+ if (noise ~= 0 or scale ~= 0) then
|
|
|
+ local hash = md5.sumhexa(blob)
|
|
|
if noise == 1 then
|
|
|
- x = cache_do(cache_noise1, x, apply_denoise1)
|
|
|
+ x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash})
|
|
|
elseif noise == 2 then
|
|
|
- x = cache_do(cache_noise2, x, apply_denoise2)
|
|
|
+ x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash})
|
|
|
end
|
|
|
if scale == 1 or scale == 2 then
|
|
|
if noise == 1 then
|
|
|
- x = cache_do(cache_noise1_scale, x, apply_scale2x)
|
|
|
+ x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash})
|
|
|
elseif noise == 2 then
|
|
|
- x = cache_do(cache_noise2_scale, x, apply_scale2x)
|
|
|
+ x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash})
|
|
|
else
|
|
|
- x = cache_do(cache_scale, x, apply_scale2x)
|
|
|
+ x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash})
|
|
|
end
|
|
|
if scale == 1 then
|
|
|
x = iproc.scale(x,
|
|
@@ -161,23 +169,9 @@ function APIHandler:post()
|
|
|
"Jinc")
|
|
|
end
|
|
|
end
|
|
|
- elseif noise ~= 0 or scale ~= 0 then
|
|
|
- if noise == 1 then
|
|
|
- x = apply_denoise1(x)
|
|
|
- elseif noise == 2 then
|
|
|
- x = apply_denoise2(x)
|
|
|
- end
|
|
|
- if scale == 1 then
|
|
|
- local x16 = {math.floor(x:size(3) * 1.6 + 0.5), math.floor(x:size(2) * 1.6 + 0.5)}
|
|
|
- x = apply_scale2x(x)
|
|
|
- x = iproc.scale(x, x16[1], x16[2], "Jinc")
|
|
|
- elseif scale == 2 then
|
|
|
- x = apply_scale2x(x)
|
|
|
- end
|
|
|
end
|
|
|
local name = uuid() .. ".png"
|
|
|
local blob, len = image_loader.encode_png(x, alpha)
|
|
|
-
|
|
|
self:set_header("Content-Disposition", string.format('filename="%s"', name))
|
|
|
self:set_header("Content-Type", "image/png")
|
|
|
self:set_header("Content-Length", string.format("%d", len))
|
|
@@ -194,9 +188,9 @@ function APIHandler:post()
|
|
|
collectgarbage()
|
|
|
end
|
|
|
local FormHandler = class("FormHandler", turbo.web.RequestHandler)
|
|
|
-local index_ja = file.read("./assets/index.ja.html")
|
|
|
-local index_ru = file.read("./assets/index.ru.html")
|
|
|
-local index_en = file.read("./assets/index.html")
|
|
|
+local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
|
|
|
+local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
|
|
|
+local index_en = file.read(path.join(ROOT, "assets", "index.html"))
|
|
|
function FormHandler:get()
|
|
|
local lang = self.request.headers:get("Accept-Language")
|
|
|
if lang then
|
|
@@ -226,9 +220,11 @@ turbo.log.categories = {
|
|
|
local app = turbo.web.Application:new(
|
|
|
{
|
|
|
{"^/$", FormHandler},
|
|
|
- {"^/index.html", turbo.web.StaticFileHandler, path.join("./assets", "index.html")},
|
|
|
- {"^/index.ja.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ja.html")},
|
|
|
- {"^/index.ru.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ru.html")},
|
|
|
+ {"^/style.css", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "style.css")},
|
|
|
+ {"^/ui.js", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "ui.js")},
|
|
|
+ {"^/index.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")},
|
|
|
+ {"^/index.ja.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")},
|
|
|
+ {"^/index.ru.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ru.html")},
|
|
|
{"^/api$", APIHandler},
|
|
|
}
|
|
|
)
|