|
@@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
|
|
|
cmd:option("-x_file", "", 'input image for user method')
|
|
|
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
|
|
|
cmd:option("-border", 0, 'border px that will removed')
|
|
|
+cmd:option("-metric", "", '(jaccard)')
|
|
|
|
|
|
local function to_bool(settings, name)
|
|
|
if settings[name] == 1 then
|
|
@@ -198,8 +199,34 @@ local function remove_border(x, border)
|
|
|
x:size(3) - border,
|
|
|
x:size(2) - border)
|
|
|
end
|
|
|
+local function create_metric(metric)
|
|
|
+ if metric and metric:len() > 0 then
|
|
|
+ if metric == "jaccard" then
|
|
|
+ return {
|
|
|
+ name = "jaccard",
|
|
|
+ func = function (a, b)
|
|
|
+ local ga = iproc.rgb2y(a)
|
|
|
+ local gb = iproc.rgb2y(b)
|
|
|
+ local ba = torch.Tensor():resizeAs(ga)
|
|
|
+ local bb = torch.Tensor():resizeAs(gb)
|
|
|
+ ba:zero()
|
|
|
+ bb:zero()
|
|
|
+ ba[torch.gt(ga, 0.5)] = 1.0
|
|
|
+ bb[torch.gt(gb, 0.5)] = 1.0
|
|
|
+ local num_a = ba:sum()
|
|
|
+ local num_b = bb:sum()
|
|
|
+ local a_and_b = ba:cmul(bb):sum()
|
|
|
+ return (a_and_b / (num_a + num_b - a_and_b))
|
|
|
+ end}
|
|
|
+ else
|
|
|
+ error("unknown metric: " .. metric)
|
|
|
+ end
|
|
|
+ else
|
|
|
+ return nil
|
|
|
+ end
|
|
|
+end
|
|
|
local function benchmark(opt, x, model1, model2)
|
|
|
- local mse1, mse2
|
|
|
+ local mse1, mse2, am1, am2
|
|
|
local won = {0, 0}
|
|
|
local model1_mse = 0
|
|
|
local model2_mse = 0
|
|
@@ -212,6 +239,13 @@ local function benchmark(opt, x, model1, model2)
|
|
|
local scale_f = reconstruct.scale
|
|
|
local image_f = reconstruct.image
|
|
|
local detail_fp = nil
|
|
|
+ local am = nil
|
|
|
+ local model1_am = 0
|
|
|
+ local model2_am = 0
|
|
|
+
|
|
|
+ if opt.method == "user" or opt.method == "diff" then
|
|
|
+ am = create_metric(opt.metric)
|
|
|
+ end
|
|
|
if opt.save_info then
|
|
|
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
|
|
|
end
|
|
@@ -401,32 +435,57 @@ local function benchmark(opt, x, model1, model2)
|
|
|
ground_truth = remove_border(ground_truth, opt.border)
|
|
|
model1_output = remove_border(model1_output, opt.border)
|
|
|
end
|
|
|
- mse1 = MSE(ground_truth, model1_output, opt.color)
|
|
|
- model1_mse = model1_mse + mse1
|
|
|
- model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
|
|
-
|
|
|
+ if am then
|
|
|
+ am1 = am.func(ground_truth, model1_output)
|
|
|
+ model1_am = model1_am + am1
|
|
|
+ else
|
|
|
+ mse1 = MSE(ground_truth, model1_output, opt.color)
|
|
|
+ model1_mse = model1_mse + mse1
|
|
|
+ model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
|
|
+ end
|
|
|
local won_model = 1
|
|
|
if model2 then
|
|
|
if opt.border > 0 then
|
|
|
model2_output = remove_border(model2_output, opt.border)
|
|
|
end
|
|
|
- mse2 = MSE(ground_truth, model2_output, opt.color)
|
|
|
- model2_mse = model2_mse + mse2
|
|
|
- model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
|
|
-
|
|
|
- if mse1 < mse2 then
|
|
|
- won[1] = won[1] + 1
|
|
|
- elseif mse1 > mse2 then
|
|
|
- won[2] = won[2] + 1
|
|
|
- won_model = 2
|
|
|
+ if am then
|
|
|
+ am2 = am.func(ground_truth, model2_output)
|
|
|
+ model2_am = model2_am + am2
|
|
|
+ else
|
|
|
+ mse2 = MSE(ground_truth, model2_output, opt.color)
|
|
|
+ model2_mse = model2_mse + mse2
|
|
|
+ model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
|
|
+ end
|
|
|
+ if am then
|
|
|
+ if am1 < am2 then
|
|
|
+ won[1] = won[1] + 1
|
|
|
+ elseif am1 > am2 then
|
|
|
+ won[2] = won[2] + 1
|
|
|
+ won_model = 2
|
|
|
+ end
|
|
|
+ else
|
|
|
+ if mse1 < mse2 then
|
|
|
+ won[1] = won[1] + 1
|
|
|
+ elseif mse1 > mse2 then
|
|
|
+ won[2] = won[2] + 1
|
|
|
+ won_model = 2
|
|
|
+ end
|
|
|
end
|
|
|
if detail_fp then
|
|
|
- detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
|
|
- MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
|
|
+ if am then
|
|
|
+ detail_fp:write(string.format("%s,%f,%d\n", x[i].basename, am1, am2, won_model))
|
|
|
+ else
|
|
|
+ detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
|
|
+ MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
|
|
+ end
|
|
|
end
|
|
|
else
|
|
|
if detail_fp then
|
|
|
- detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
|
|
+ if am then
|
|
|
+ detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
|
|
|
+ else
|
|
|
+ detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
|
|
+ end
|
|
|
end
|
|
|
end
|
|
|
if baseline_output then
|
|
@@ -450,46 +509,65 @@ local function benchmark(opt, x, model1, model2)
|
|
|
end
|
|
|
end
|
|
|
if opt.show_progress or i == #x then
|
|
|
- if model2 then
|
|
|
- if baseline_output then
|
|
|
+ if am then
|
|
|
+ if model2 then
|
|
|
io.stdout:write(
|
|
|
- string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
|
|
|
+ string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
|
|
|
i, #x,
|
|
|
model1_time,
|
|
|
model2_time,
|
|
|
- math.sqrt(baseline_mse / i),
|
|
|
- math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
|
|
- baseline_psnr / i,
|
|
|
- model1_psnr / i, model2_psnr / i,
|
|
|
- won[1], won[2]
|
|
|
- ))
|
|
|
+ am.name, model1_am / i, am.name, model2_am / i
|
|
|
+ ))
|
|
|
else
|
|
|
io.stdout:write(
|
|
|
- string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
|
|
|
+ string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
|
|
|
i, #x,
|
|
|
model1_time,
|
|
|
- model2_time,
|
|
|
- math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
|
|
- model1_psnr / i, model2_psnr / i,
|
|
|
- won[1], won[2]
|
|
|
- ))
|
|
|
+ am.name, model1_am / i
|
|
|
+ ))
|
|
|
end
|
|
|
else
|
|
|
- if baseline_output then
|
|
|
- io.stdout:write(
|
|
|
- string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
|
|
|
- i, #x,
|
|
|
- model1_time,
|
|
|
- math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
|
|
- baseline_psnr / i, model1_psnr / i
|
|
|
+ if model2 then
|
|
|
+ if baseline_output then
|
|
|
+ io.stdout:write(
|
|
|
+ string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
|
|
|
+ i, #x,
|
|
|
+ model1_time,
|
|
|
+ model2_time,
|
|
|
+ math.sqrt(baseline_mse / i),
|
|
|
+ math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
|
|
+ baseline_psnr / i,
|
|
|
+ model1_psnr / i, model2_psnr / i,
|
|
|
+ won[1], won[2]
|
|
|
+ ))
|
|
|
+ else
|
|
|
+ io.stdout:write(
|
|
|
+ string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
|
|
|
+ i, #x,
|
|
|
+ model1_time,
|
|
|
+ model2_time,
|
|
|
+ math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
|
|
+ model1_psnr / i, model2_psnr / i,
|
|
|
+ won[1], won[2]
|
|
|
))
|
|
|
+ end
|
|
|
else
|
|
|
- io.stdout:write(
|
|
|
- string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
|
|
|
- i, #x,
|
|
|
- model1_time,
|
|
|
- math.sqrt(model1_mse / i), model1_psnr / i
|
|
|
+ if baseline_output then
|
|
|
+ io.stdout:write(
|
|
|
+ string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
|
|
|
+ i, #x,
|
|
|
+ model1_time,
|
|
|
+ math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
|
|
+ baseline_psnr / i, model1_psnr / i
|
|
|
))
|
|
|
+ else
|
|
|
+ io.stdout:write(
|
|
|
+ string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
|
|
|
+ i, #x,
|
|
|
+ model1_time,
|
|
|
+ math.sqrt(model1_mse / i), model1_psnr / i
|
|
|
+ ))
|
|
|
+ end
|
|
|
end
|
|
|
end
|
|
|
io.stdout:flush()
|
|
@@ -510,6 +588,14 @@ local function benchmark(opt, x, model1, model2)
|
|
|
fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
|
|
|
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
|
|
|
end
|
|
|
+ if model1_am > 0 then
|
|
|
+ fp:write(string.format("model1 : %s = %.3f, evaluation time = %.3f\n",
|
|
|
+ math.sqrt(model1_am / #x), model1_time))
|
|
|
+ end
|
|
|
+ if model2_am > 0 then
|
|
|
+ fp:write(string.format("model2 : %s = %.3f, evaluation time = %.3f\n",
|
|
|
+ math.sqrt(model2_am / #x), model2_time))
|
|
|
+ end
|
|
|
fp:close()
|
|
|
if detail_fp then
|
|
|
detail_fp:close()
|