|
@@ -14,8 +14,8 @@ cmd:text("waifu2x-benchmark")
|
|
cmd:text("Options:")
|
|
cmd:text("Options:")
|
|
|
|
|
|
cmd:option("-dir", "./data/test", 'test image directory')
|
|
cmd:option("-dir", "./data/test", 'test image directory')
|
|
-cmd:option("-model1_dir", "./models/anime_style_art", 'model1 directory')
|
|
|
|
-cmd:option("-model2_dir", "./models/anime_style_art_rgb", 'model2 directory')
|
|
|
|
|
|
+cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
|
|
|
|
+cmd:option("-model2_dir", "", 'model2 directory (optional)')
|
|
cmd:option("-method", "scale", '(scale|noise)')
|
|
cmd:option("-method", "scale", '(scale|noise)')
|
|
cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)")
|
|
cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)")
|
|
cmd:option("-color", "rgb", '(rgb|y)')
|
|
cmd:option("-color", "rgb", '(rgb|y)')
|
|
@@ -83,32 +83,46 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|
t = sys.clock()
|
|
t = sys.clock()
|
|
if input:size(3) == ground_truth:size(3) then
|
|
if input:size(3) == ground_truth:size(3) then
|
|
model1_output = reconstruct.image(model1, input)
|
|
model1_output = reconstruct.image(model1, input)
|
|
- model2_output = reconstruct.image(model2, input)
|
|
|
|
|
|
+ if model2 then
|
|
|
|
+ model2_output = reconstruct.image(model2, input)
|
|
|
|
+ end
|
|
else
|
|
else
|
|
model1_output = reconstruct.scale(model1, 2.0, input)
|
|
model1_output = reconstruct.scale(model1, 2.0, input)
|
|
- model2_output = reconstruct.scale(model2, 2.0, input)
|
|
|
|
|
|
+ if model2 then
|
|
|
|
+ model2_output = reconstruct.scale(model2, 2.0, input)
|
|
|
|
+ end
|
|
end
|
|
end
|
|
if opt.color == "y" then
|
|
if opt.color == "y" then
|
|
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
|
|
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
|
|
model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
|
|
model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
|
|
- model2_mse = model2_mse + YMSE(ground_truth, model2_output)
|
|
|
|
- model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
|
|
|
|
|
|
+ if model2 then
|
|
|
|
+ model2_mse = model2_mse + YMSE(ground_truth, model2_output)
|
|
|
|
+ model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
|
|
|
|
+ end
|
|
elseif opt.color == "rgb" then
|
|
elseif opt.color == "rgb" then
|
|
model1_mse = model1_mse + MSE(ground_truth, model1_output)
|
|
model1_mse = model1_mse + MSE(ground_truth, model1_output)
|
|
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
|
|
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
|
|
- model2_mse = model2_mse + MSE(ground_truth, model2_output)
|
|
|
|
- model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
|
|
|
|
|
|
+ if model2 then
|
|
|
|
+ model2_mse = model2_mse + MSE(ground_truth, model2_output)
|
|
|
|
+ model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
|
|
|
|
+ end
|
|
else
|
|
else
|
|
error("Unknown color: " .. opt.color)
|
|
error("Unknown color: " .. opt.color)
|
|
end
|
|
end
|
|
-
|
|
|
|
- io.stdout:write(
|
|
|
|
- string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
|
|
|
- i, #x,
|
|
|
|
- model1_mse / i, model2_mse / i,
|
|
|
|
- model1_psnr / i, model2_psnr / i
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
|
|
+ if model2 then
|
|
|
|
+ io.stdout:write(
|
|
|
|
+ string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
|
|
|
+ i, #x,
|
|
|
|
+ model1_mse / i, model2_mse / i,
|
|
|
|
+ model1_psnr / i, model2_psnr / i
|
|
|
|
+ ))
|
|
|
|
+ else
|
|
|
|
+ io.stdout:write(
|
|
|
|
+ string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
|
|
|
|
+ i, #x,
|
|
|
|
+ model1_mse / i, model1_psnr / i
|
|
|
|
+ ))
|
|
|
|
+ end
|
|
io.stdout:flush()
|
|
io.stdout:flush()
|
|
end
|
|
end
|
|
io.stdout:write("\n")
|
|
io.stdout:write("\n")
|
|
@@ -122,16 +136,34 @@ local function load_data(test_dir)
|
|
end
|
|
end
|
|
return test_x
|
|
return test_x
|
|
end
|
|
end
|
|
-
|
|
|
|
|
|
+function load_model(filename)
|
|
|
|
+ return torch.load(filename, "ascii")
|
|
|
|
+end
|
|
print(opt)
|
|
print(opt)
|
|
if opt.method == "scale" then
|
|
if opt.method == "scale" then
|
|
- local model1 = torch.load(path.join(opt.model1_dir, "scale2.0x_model.t7"), "ascii")
|
|
|
|
- local model2 = torch.load(path.join(opt.model2_dir, "scale2.0x_model.t7"), "ascii")
|
|
|
|
|
|
+ local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
|
|
|
|
+ local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
|
|
|
|
+ local s1, model1 = pcall(load_model, f1)
|
|
|
|
+ local s2, model2 = pcall(load_model, f2)
|
|
|
|
+ if not s1 then
|
|
|
|
+ error("Load error: " .. f1)
|
|
|
|
+ end
|
|
|
|
+ if not s2 then
|
|
|
|
+ model2 = nil
|
|
|
|
+ end
|
|
local test_x = load_data(opt.dir)
|
|
local test_x = load_data(opt.dir)
|
|
benchmark(opt, test_x, transform_scale, model1, model2)
|
|
benchmark(opt, test_x, transform_scale, model1, model2)
|
|
elseif opt.method == "noise" then
|
|
elseif opt.method == "noise" then
|
|
- local model1 = torch.load(path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
|
|
|
|
- local model2 = torch.load(path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
|
|
|
|
|
|
+ local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
|
|
|
|
+ local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level))
|
|
|
|
+ local s1, model1 = pcall(load_model, f1)
|
|
|
|
+ local s2, model2 = pcall(load_model, f2)
|
|
|
|
+ if not s1 then
|
|
|
|
+ error("Load error: " .. f1)
|
|
|
|
+ end
|
|
|
|
+ if not s2 then
|
|
|
|
+ model2 = nil
|
|
|
|
+ end
|
|
local test_x = load_data(opt.dir)
|
|
local test_x = load_data(opt.dir)
|
|
benchmark(opt, test_x, transform_jpeg, model1, model2)
|
|
benchmark(opt, test_x, transform_jpeg, model1, model2)
|
|
end
|
|
end
|