|
@@ -69,12 +69,14 @@ end
|
|
|
local function benchmark(opt, x, input_func, model1, model2)
|
|
|
local model1_mse = 0
|
|
|
local model2_mse = 0
|
|
|
+ local baseline_mse = 0
|
|
|
local model1_psnr = 0
|
|
|
local model2_psnr = 0
|
|
|
+ local baseline_psnr = 0
|
|
|
|
|
|
for i = 1, #x do
|
|
|
local ground_truth = x[i]
|
|
|
- local input, model1_output, model2_output
|
|
|
+ local input, model1_output, model2_output, baseline_output
|
|
|
|
|
|
input = input_func(ground_truth, opt)
|
|
|
input = input:float():div(255)
|
|
@@ -91,6 +93,7 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|
|
if model2 then
|
|
|
model2_output = reconstruct.scale(model2, 2.0, input)
|
|
|
end
|
|
|
+ baseline_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, opt.filter)
|
|
|
end
|
|
|
if opt.color == "y" then
|
|
|
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
|
|
@@ -99,6 +102,10 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|
|
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
|
|
|
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
|
|
|
end
|
|
|
+ if baseline_output then
|
|
|
+ baseline_mse = baseline_mse + YMSE(ground_truth, baseline_output)
|
|
|
+ baseline_psnr = baseline_psnr + YPSNR(ground_truth, baseline_output)
|
|
|
+ end
|
|
|
elseif opt.color == "rgb" then
|
|
|
model1_mse = model1_mse + MSE(ground_truth, model1_output)
|
|
|
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
|
|
@@ -106,22 +113,46 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|
|
model2_mse = model2_mse + MSE(ground_truth, model2_output)
|
|
|
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
|
|
|
end
|
|
|
+ if baseline_output then
|
|
|
+ baseline_mse = baseline_mse + MSE(ground_truth, baseline_output)
|
|
|
+ baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output)
|
|
|
+ end
|
|
|
else
|
|
|
error("Unknown color: " .. opt.color)
|
|
|
end
|
|
|
if model2 then
|
|
|
- io.stdout:write(
|
|
|
- string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
|
|
- i, #x,
|
|
|
- model1_mse / i, model2_mse / i,
|
|
|
- model1_psnr / i, model2_psnr / i
|
|
|
+ if baseline_output then
|
|
|
+ io.stdout:write(
|
|
|
+ string.format("%d/%d; baseline_mse=%f, model1_mse=%f, model2_mse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
|
|
|
+ i, #x,
|
|
|
+ baseline_mse / i,
|
|
|
+ model1_mse / i, model2_mse / i,
|
|
|
+ baseline_psnr / i,
|
|
|
+ model1_psnr / i, model2_psnr / i
|
|
|
+ ))
|
|
|
+ else
|
|
|
+ io.stdout:write(
|
|
|
+ string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
|
|
+ i, #x,
|
|
|
+ model1_mse / i, model2_mse / i,
|
|
|
+ model1_psnr / i, model2_psnr / i
|
|
|
))
|
|
|
+ end
|
|
|
else
|
|
|
- io.stdout:write(
|
|
|
- string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
|
|
|
- i, #x,
|
|
|
- model1_mse / i, model1_psnr / i
|
|
|
+ if baseline_output then
|
|
|
+ io.stdout:write(
|
|
|
+ string.format("%d/%d; baseline_mse=%f, model1_mse=%f, baseline_psnr=%f, model1_psnr=%f \r",
|
|
|
+ i, #x,
|
|
|
+ baseline_mse / i, model1_mse / i,
|
|
|
+ baseline_psnr / i, model1_psnr / i
|
|
|
+ ))
|
|
|
+ else
|
|
|
+ io.stdout:write(
|
|
|
+ string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
|
|
|
+ i, #x,
|
|
|
+ model1_mse / i, model1_psnr / i
|
|
|
))
|
|
|
+ end
|
|
|
end
|
|
|
io.stdout:flush()
|
|
|
end
|