|
@@ -12,6 +12,7 @@ local iproc = require 'iproc'
|
|
|
local reconstruct = require 'reconstruct'
|
|
|
local image_loader = require 'image_loader'
|
|
|
local alpha_util = require 'alpha_util'
|
|
|
+local compression = require 'compression'
|
|
|
local gm = require 'graphicsmagick'
|
|
|
|
|
|
-- Note: turbo and xlua has different implementation of string:split().
|
|
@@ -34,6 +35,7 @@ cmd:option("-curl_request_timeout", 60, "request_timeout for curl")
|
|
|
cmd:option("-curl_connect_timeout", 60, "connect_timeout for curl")
|
|
|
cmd:option("-curl_max_redirects", 2, "max_redirects for curl")
|
|
|
cmd:option("-max_body_size", 5 * 1024 * 1024, "maximum allowed size for uploaded files")
|
|
|
+cmd:option("-cache_max", 200, "number of cached images on RAM")
|
|
|
|
|
|
local opt = cmd:parse(arg)
|
|
|
cutorch.setDevice(opt.gpu)
|
|
@@ -75,6 +77,7 @@ local CLEANUP_MODEL = false -- if you are using the low memory GPU, you could us
|
|
|
local CACHE_DIR = path.join(ROOT, "cache")
|
|
|
local MAX_NOISE_IMAGE = opt.max_pixels
|
|
|
local MAX_SCALE_IMAGE = (math.sqrt(opt.max_pixels) / 2)^2
|
|
|
+local PNG_DEPTH = 8
|
|
|
local CURL_OPTIONS = {
|
|
|
request_timeout = opt.curl_request_timeout,
|
|
|
connect_timeout = opt.curl_connect_timeout,
|
|
@@ -167,26 +170,73 @@ local function cleanup_model(model)
|
|
|
model:clearState() -- release GPU memory
|
|
|
end
|
|
|
end
|
|
|
+
|
|
|
+-- cache
|
|
|
+local g_cache = {}
|
|
|
+local function cache_count()
|
|
|
+ local count = 0
|
|
|
+ for _ in pairs(g_cache) do
|
|
|
+ count = count + 1
|
|
|
+ end
|
|
|
+ return count
|
|
|
+end
|
|
|
+local function cache_remove_old()
|
|
|
+ local old_time = nil
|
|
|
+ local old_key = nil
|
|
|
+ for k, v in pairs(g_cache) do
|
|
|
+ if old_time == nil or old_time > v.updated_at then
|
|
|
+ old_key = k
|
|
|
+ old_time = v.updated_at
|
|
|
+ end
|
|
|
+ end
|
|
|
+ if old_key then
|
|
|
+ g_cache[old_key] = nil
|
|
|
+ end
|
|
|
+end
|
|
|
+local function cache_compress(raw_image)
|
|
|
+ if raw_image then
|
|
|
+ compressed_image = compression.compress(iproc.float2byte(raw_image))
|
|
|
+ return compressed_image
|
|
|
+ else
|
|
|
+ return nil
|
|
|
+ end
|
|
|
+end
|
|
|
+local function cache_decompress(compressed_image)
|
|
|
+ if compressed_image then
|
|
|
+ local raw_image = compression.decompress(compressed_image)
|
|
|
+ return iproc.byte2float(raw_image)
|
|
|
+ else
|
|
|
+ return nil
|
|
|
+ end
|
|
|
+end
|
|
|
+local function cache_get(filename)
|
|
|
+ local cache = g_cache[filename]
|
|
|
+ if cache then
|
|
|
+ return {image = cache_decompress(cache.image),
|
|
|
+ alpha = cache_decompress(cache.alpha)}
|
|
|
+ else
|
|
|
+ return nil
|
|
|
+ end
|
|
|
+end
|
|
|
+local function cache_put(filename, image, alpha)
|
|
|
+ g_cache[filename] = {image = cache_compress(image),
|
|
|
+ alpha = cache_compress(alpha),
|
|
|
+ updated_at = os.time()};
|
|
|
+ local count = cache_count(g_cache)
|
|
|
+ if count > opt.cache_max then
|
|
|
+ cache_remove_old()
|
|
|
+ end
|
|
|
+end
|
|
|
local function convert(x, meta, options)
|
|
|
local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
|
|
|
- local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
|
|
|
local alpha = meta.alpha
|
|
|
local alpha_orig = alpha
|
|
|
+ local cache = cache_get(cache_file)
|
|
|
|
|
|
- if path.exists(alpha_cache_file) then
|
|
|
- alpha = image_loader.load_float(alpha_cache_file)
|
|
|
- if alpha:dim() == 2 then
|
|
|
- alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
|
|
|
- end
|
|
|
- if alpha:size(1) == 3 then
|
|
|
- alpha = image.rgb2y(alpha)
|
|
|
- end
|
|
|
- end
|
|
|
- if path.exists(cache_file) then
|
|
|
- x = image_loader.load_float(cache_file)
|
|
|
+ if cache then
|
|
|
meta = tablex.copy(meta)
|
|
|
- meta.alpha = alpha
|
|
|
- return x, meta
|
|
|
+ meta.alpha = cache.alpha
|
|
|
+ return cache.image, meta
|
|
|
else
|
|
|
local model = nil
|
|
|
if options.style == "art" then
|
|
@@ -209,7 +259,6 @@ local function convert(x, meta, options)
|
|
|
if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
|
|
|
alpha = reconstruct.scale(model.scale, 2.0, alpha,
|
|
|
opt.crop_size, opt.batch_size)
|
|
|
- image_loader.save_png(alpha_cache_file, alpha)
|
|
|
cleanup_model(model.scale)
|
|
|
end
|
|
|
end
|
|
@@ -223,7 +272,7 @@ local function convert(x, meta, options)
|
|
|
x, opt.crop_size, opt.batch_size)
|
|
|
cleanup_model(model[options.method])
|
|
|
end
|
|
|
- image_loader.save_png(cache_file, x)
|
|
|
+ cache_put(cache_file, x, alpha)
|
|
|
meta = tablex.copy(meta)
|
|
|
meta.alpha = alpha
|
|
|
return x, meta
|
|
@@ -259,6 +308,12 @@ function APIHandler:post()
|
|
|
local style = self:get_argument("style", "art")
|
|
|
local download = (self:get_argument("download", "")):len()
|
|
|
|
|
|
+ if client_disconnected(self) then
|
|
|
+ self:set_status(400)
|
|
|
+ self:write("client disconnected")
|
|
|
+ return
|
|
|
+ end
|
|
|
+
|
|
|
if tta_level == 0 then
|
|
|
tta_level = auto_tta_level(x, scale)
|
|
|
end
|
|
@@ -321,7 +376,7 @@ function APIHandler:post()
|
|
|
name = uuid() .. ".png"
|
|
|
end
|
|
|
local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
|
|
|
- tablex.update({depth = 8, inplace = true}, meta))
|
|
|
+ tablex.update({depth = PNG_DEPTH, inplace = true}, meta))
|
|
|
|
|
|
self:set_header("Content-Length", string.format("%d", #blob))
|
|
|
if download > 0 then
|
|
@@ -354,6 +409,8 @@ local index_tr = file.read(path.join(ROOT, "assets", "index.tr.html"))
|
|
|
local index_zh_cn = file.read(path.join(ROOT, "assets", "index.zh-CN.html"))
|
|
|
local index_zh_tw = file.read(path.join(ROOT, "assets", "index.zh-TW.html"))
|
|
|
local index_ko = file.read(path.join(ROOT, "assets", "index.ko.html"))
|
|
|
+local index_nl = file.read(path.join(ROOT, "assets", "index.nl.html"))
|
|
|
+local index_ca = file.read(path.join(ROOT, "assets", "index.ca.html"))
|
|
|
local index_en = file.read(path.join(ROOT, "assets", "index.html"))
|
|
|
function FormHandler:get()
|
|
|
local lang = self.request.headers:get("Accept-Language")
|
|
@@ -382,6 +439,10 @@ function FormHandler:get()
|
|
|
self:write(index_zh_tw)
|
|
|
elseif langs[1] == "ko" then
|
|
|
self:write(index_ko)
|
|
|
+ elseif langs[1] == "nl" then
|
|
|
+ self:write(index_nl)
|
|
|
+ elseif langs[1] == "ca" or langs[1] == "ca-ES" or langs[1] == "ca-FR" or langs[1] == "ca-IT" or langs[1] == "ca-AD" then
|
|
|
+ self:write(index_ca)
|
|
|
else
|
|
|
self:write(index_en)
|
|
|
end
|