|
@@ -112,6 +112,24 @@ local function convert_image(opt)
|
|
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
|
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
|
end
|
|
end
|
|
end
|
|
end
|
|
|
|
+ elseif opt.m == "user" then
|
|
|
|
+ local model_path = opt.model_path
|
|
|
|
+ local model = w2nn.load_model(model_path, opt.force_cudnn)
|
|
|
|
+ if not model then
|
|
|
|
+ error("Load Error: " .. model_path)
|
|
|
|
+ end
|
|
|
|
+ local t = sys.clock()
|
|
|
|
+
|
|
|
|
+ x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model))
|
|
|
|
+ if opt.scale == 1 then
|
|
|
|
+ new_x = image_f(model, x, opt.crop_size, opt.batch_size)
|
|
|
|
+ else
|
|
|
|
+ new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.batch_size)
|
|
|
|
+ end
|
|
|
|
+ new_x = alpha_util.composite(new_x, alpha) -- TODO: should it use model?
|
|
|
|
+ if not opt.q then
|
|
|
|
+ print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
|
|
|
+ end
|
|
else
|
|
else
|
|
error("undefined method:" .. opt.method)
|
|
error("undefined method:" .. opt.method)
|
|
end
|
|
end
|
|
@@ -121,6 +139,7 @@ local function convert_frames(opt)
|
|
local model_path, scale_model, t
|
|
local model_path, scale_model, t
|
|
local noise_scale_model = {}
|
|
local noise_scale_model = {}
|
|
local noise_model = {}
|
|
local noise_model = {}
|
|
|
|
+ local user_model = nil
|
|
local scale_f, image_f
|
|
local scale_f, image_f
|
|
if opt.tta == 1 then
|
|
if opt.tta == 1 then
|
|
scale_f = function(model, scale, x, block_size, batch_size)
|
|
scale_f = function(model, scale, x, block_size, batch_size)
|
|
@@ -156,6 +175,8 @@ local function convert_frames(opt)
|
|
model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
|
|
model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
|
|
noise_model[opt.noise_level] = w2nn.load_model(model_path, opt.force_cudnn)
|
|
noise_model[opt.noise_level] = w2nn.load_model(model_path, opt.force_cudnn)
|
|
end
|
|
end
|
|
|
|
+ elseif opt.m == "user" then
|
|
|
|
+ user_model = w2nn.load_model(opt.model_path, opt.force_cudnn)
|
|
end
|
|
end
|
|
local fp = io.open(opt.l)
|
|
local fp = io.open(opt.l)
|
|
if not fp then
|
|
if not fp then
|
|
@@ -189,6 +210,14 @@ local function convert_frames(opt)
|
|
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size)
|
|
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size)
|
|
end
|
|
end
|
|
new_x = alpha_util.composite(new_x, alpha, scale_model)
|
|
new_x = alpha_util.composite(new_x, alpha, scale_model)
|
|
|
|
+ elseif opt.m == "user" then
|
|
|
|
+ x = alpha_util.make_border(x, alpha, reconstruct.offset_size(user_model))
|
|
|
|
+ if opt.scale == 1 then
|
|
|
|
+ new_x = image_f(user_model, x, opt.crop_size, opt.batch_size)
|
|
|
|
+ else
|
|
|
|
+ new_x = scale_f(user_model, opt.scale, x, opt.crop_size, opt.batch_size)
|
|
|
|
+ end
|
|
|
|
+ new_x = alpha_util.composite(new_x, alpha)
|
|
else
|
|
else
|
|
error("undefined method:" .. opt.method)
|
|
error("undefined method:" .. opt.method)
|
|
end
|
|
end
|
|
@@ -218,7 +247,8 @@ local function waifu2x()
|
|
cmd:option("-o", "(auto)", 'path to output file')
|
|
cmd:option("-o", "(auto)", 'path to output file')
|
|
cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
|
|
cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
|
|
cmd:option("-model_dir", "./models/upconv_7/art", 'path to model directory')
|
|
cmd:option("-model_dir", "./models/upconv_7/art", 'path to model directory')
|
|
- cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
|
|
|
|
|
|
+ cmd:option("-name", "user", 'model name for user method')
|
|
|
|
+ cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale|user)')
|
|
cmd:option("-method", "", 'same as -m')
|
|
cmd:option("-method", "", 'same as -m')
|
|
cmd:option("-noise_level", 1, '(1|2|3)')
|
|
cmd:option("-noise_level", 1, '(1|2|3)')
|
|
cmd:option("-crop_size", 128, 'patch size per process')
|
|
cmd:option("-crop_size", 128, 'patch size per process')
|
|
@@ -247,6 +277,7 @@ local function waifu2x()
|
|
end
|
|
end
|
|
opt.force_cudnn = opt.force_cudnn == 1
|
|
opt.force_cudnn = opt.force_cudnn == 1
|
|
opt.q = opt.q == 1
|
|
opt.q = opt.q == 1
|
|
|
|
+ opt.model_path = path.join(opt.model_dir, string.format("%s_model.t7", opt.name))
|
|
|
|
|
|
if string.len(opt.l) == 0 then
|
|
if string.len(opt.l) == 0 then
|
|
convert_image(opt)
|
|
convert_image(opt)
|