|
@@ -18,7 +18,7 @@ cmd:option("-dir", "./data/test", 'test image directory')
|
|
cmd:option("-file", "", 'test image file list')
|
|
cmd:option("-file", "", 'test image file list')
|
|
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
|
|
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
|
|
cmd:option("-model2_dir", "", 'model2 directory (optional)')
|
|
cmd:option("-model2_dir", "", 'model2 directory (optional)')
|
|
-cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff)')
|
|
|
|
|
|
+cmd:option("-method", "scale", '(scale|noise|noise_scale|user|diff|scale4)')
|
|
cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))")
|
|
cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))")
|
|
cmd:option("-resize_blur", 1.0, 'blur parameter for resize')
|
|
cmd:option("-resize_blur", 1.0, 'blur parameter for resize')
|
|
cmd:option("-color", "y", '(rgb|y|r|g|b)')
|
|
cmd:option("-color", "y", '(rgb|y|r|g|b)')
|
|
@@ -154,12 +154,24 @@ local function baseline_scale(x, filter)
|
|
x:size(2) * 2.0,
|
|
x:size(2) * 2.0,
|
|
filter)
|
|
filter)
|
|
end
|
|
end
|
|
|
|
+local function baseline_scale4(x, filter)
|
|
|
|
+ return iproc.scale(x,
|
|
|
|
+ x:size(3) * 4.0,
|
|
|
|
+ x:size(2) * 4.0,
|
|
|
|
+ filter)
|
|
|
|
+end
|
|
local function transform_scale(x, opt)
|
|
local function transform_scale(x, opt)
|
|
return iproc.scale(x,
|
|
return iproc.scale(x,
|
|
x:size(3) * 0.5,
|
|
x:size(3) * 0.5,
|
|
x:size(2) * 0.5,
|
|
x:size(2) * 0.5,
|
|
opt.filter, opt.resize_blur)
|
|
opt.filter, opt.resize_blur)
|
|
end
|
|
end
|
|
|
|
+local function transform_scale4(x, opt)
|
|
|
|
+ return iproc.scale(x,
|
|
|
|
+ x:size(3) * 0.25,
|
|
|
|
+ x:size(2) * 0.25,
|
|
|
|
+ opt.filter, opt.resize_blur)
|
|
|
|
+end
|
|
|
|
|
|
local function transform_scale_jpeg(x, opt)
|
|
local function transform_scale_jpeg(x, opt)
|
|
x = iproc.scale(x,
|
|
x = iproc.scale(x,
|
|
@@ -237,6 +249,26 @@ local function benchmark(opt, x, model1, model2)
|
|
model2_time = model2_time + (sys.clock() - t)
|
|
model2_time = model2_time + (sys.clock() - t)
|
|
end
|
|
end
|
|
baseline_output = baseline_scale(input, opt.baseline_filter)
|
|
baseline_output = baseline_scale(input, opt.baseline_filter)
|
|
|
|
+ elseif opt.method == "scale4" then
|
|
|
|
+ input = transform_scale4(x[i].y, opt)
|
|
|
|
+ ground_truth = x[i].y
|
|
|
|
+ if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
|
|
|
|
+ model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
|
|
|
|
+ if model2 then
|
|
|
|
+ model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+ t = sys.clock()
|
|
|
|
+ model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
|
|
|
|
+ model1_output = scale_f(model1, 2.0, model1_output, opt.crop_size, opt.batch_size)
|
|
|
|
+ model1_time = model1_time + (sys.clock() - t)
|
|
|
|
+ if model2 then
|
|
|
|
+ t = sys.clock()
|
|
|
|
+ model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
|
|
|
|
+ model2_output = scale_f(model2, 2.0, model2_output, opt.crop_size, opt.batch_size)
|
|
|
|
+ model2_time = model2_time + (sys.clock() - t)
|
|
|
|
+ end
|
|
|
|
+ baseline_output = baseline_scale4(input, opt.baseline_filter)
|
|
elseif opt.method == "noise" then
|
|
elseif opt.method == "noise" then
|
|
input = transform_jpeg(x[i].y, opt)
|
|
input = transform_jpeg(x[i].y, opt)
|
|
ground_truth = x[i].y
|
|
ground_truth = x[i].y
|
|
@@ -604,7 +636,7 @@ if opt.show_progress then
|
|
print(opt)
|
|
print(opt)
|
|
end
|
|
end
|
|
|
|
|
|
-if opt.method == "scale" then
|
|
|
|
|
|
+if opt.method == "scale" or opt.method == "scale4" then
|
|
local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
|
|
local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
|
|
local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
|
|
local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
|
|
local s1, model1 = pcall(w2nn.load_model, f1, opt.force_cudnn)
|
|
local s1, model1 = pcall(w2nn.load_model, f1, opt.force_cudnn)
|