ソースを参照

add support for jaccard in benchmark

nagadomi 8 年 前
コミット
d8b7df4505
1 ファイル変更130 行追加44 行削除
  1. 130 44
      tools/benchmark.lua

+ 130 - 44
tools/benchmark.lua

@@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
 cmd:option("-x_file", "", 'input image for user method')
 cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
 cmd:option("-border", 0, 'border px that will removed')
+cmd:option("-metric", "", '(jaccard)')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -198,8 +199,34 @@ local function remove_border(x, border)
 		     x:size(3) - border,
 		     x:size(2) - border)
 end
+local function create_metric(metric)
+   if metric and metric:len() > 0 then
+      if metric == "jaccard" then
+	 return {
+	    name = "jaccard", 
+	    func = function (a, b) 
+	       local ga = iproc.rgb2y(a)
+	       local gb = iproc.rgb2y(b)
+	       local ba = torch.Tensor():resizeAs(ga)
+	       local bb = torch.Tensor():resizeAs(gb)
+	       ba:zero()
+	       bb:zero()
+	       ba[torch.gt(ga, 0.5)] = 1.0
+	       bb[torch.gt(gb, 0.5)] = 1.0
+	       local num_a = ba:sum()
+	       local num_b = bb:sum()
+	       local a_and_b  = ba:cmul(bb):sum()
+	       return (a_and_b / (num_a + num_b - a_and_b))
+	 end}
+      else
+	 error("unknown metric: " .. metric)
+      end
+   else
+      return nil
+   end
+end
 local function benchmark(opt, x, model1, model2)
-   local mse1, mse2
+   local mse1, mse2, am1, am2
    local won = {0, 0}
    local model1_mse = 0
    local model2_mse = 0
@@ -212,6 +239,13 @@ local function benchmark(opt, x, model1, model2)
    local scale_f = reconstruct.scale
    local image_f = reconstruct.image
    local detail_fp = nil
+   local am = nil
+   local model1_am = 0
+   local model2_am = 0
+
+   if opt.method == "user" or opt.method == "diff" then
+      am = create_metric(opt.metric)
+   end
    if opt.save_info then
       detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
    end
@@ -401,32 +435,57 @@ local function benchmark(opt, x, model1, model2)
 	 ground_truth = remove_border(ground_truth, opt.border)
 	 model1_output = remove_border(model1_output, opt.border)
       end
-      mse1 = MSE(ground_truth, model1_output, opt.color)
-      model1_mse = model1_mse + mse1
-      model1_psnr = model1_psnr + MSE2PSNR(mse1)
-
+      if am then
+	 am1 = am.func(ground_truth, model1_output)
+	 model1_am = model1_am + am1
+      else
+	 mse1 = MSE(ground_truth, model1_output, opt.color)
+	 model1_mse = model1_mse + mse1
+	 model1_psnr = model1_psnr + MSE2PSNR(mse1)
+      end
       local won_model = 1
       if model2 then
 	 if opt.border > 0 then
 	    model2_output = remove_border(model2_output, opt.border)
 	 end
-	 mse2 = MSE(ground_truth, model2_output, opt.color)
-	 model2_mse = model2_mse + mse2
-	 model2_psnr = model2_psnr + MSE2PSNR(mse2)
-
-	 if mse1 < mse2 then
-	    won[1] = won[1] + 1
-	 elseif mse1 > mse2 then
-	    won[2] = won[2] + 1
-	    won_model = 2
+	 if am then
+	    am2 = am.func(ground_truth, model2_output)
+	    model2_am = model2_am + am2
+	 else
+	    mse2 = MSE(ground_truth, model2_output, opt.color)
+	    model2_mse = model2_mse + mse2
+	    model2_psnr = model2_psnr + MSE2PSNR(mse2)
+	 end
+	 if am then
+	    if am1 < am2 then
+	       won[1] = won[1] + 1
+	    elseif am1 > am2 then
+	       won[2] = won[2] + 1
+	       won_model = 2
+	    end
+	 else
+	    if mse1 < mse2 then
+	       won[1] = won[1] + 1
+	    elseif mse1 > mse2 then
+	       won[2] = won[2] + 1
+	       won_model = 2
+	    end
 	 end
 	 if detail_fp then
-	    detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
-					  MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+	    if am then
+	       detail_fp:write(string.format("%s,%f,%d\n", x[i].basename, am1, am2, won_model))
+	    else
+	       detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
+					     MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+	    end
 	 end
       else
 	 if detail_fp then
-	    detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
+	    if am then
+	       detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
+	    else
+	       detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
+	    end
 	 end
       end
       if baseline_output then
@@ -450,46 +509,65 @@ local function benchmark(opt, x, model1, model2)
 	 end
       end
       if opt.show_progress or i == #x then
-	 if model2 then
-	    if baseline_output then
+	 if am then
+	    if model2 then
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
+		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
 				i, #x,
 				model1_time,
 				model2_time,
-				math.sqrt(baseline_mse / i),
-				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-				baseline_psnr / i,
-				model1_psnr / i, model2_psnr / i,
-				won[1], won[2]
-		  ))
+				am.name, model1_am / i, am.name, model2_am / i
+	       ))
 	    else
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
+		  string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
 				i, #x,
 				model1_time,
-				model2_time,
-				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-				model1_psnr / i, model2_psnr / i,
-				won[1], won[2]
-		  ))
+				am.name, model1_am / i
+	       ))
 	    end
 	 else
-	    if baseline_output then
-	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
-				i, #x,
-				model1_time,
-				math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
-				baseline_psnr / i, model1_psnr / i
+	    if model2 then
+	       if baseline_output then
+		  io.stdout:write(
+		     string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
+				   i, #x,
+				   model1_time,
+				   model2_time,
+				   math.sqrt(baseline_mse / i),
+				   math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
+				   baseline_psnr / i,
+				   model1_psnr / i, model2_psnr / i,
+				   won[1], won[2]
+		  ))
+	       else
+		  io.stdout:write(
+		     string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
+				   i, #x,
+				   model1_time,
+				   model2_time,
+				   math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
+				   model1_psnr / i, model2_psnr / i,
+				   won[1], won[2]
 		  ))
+	       end
 	    else
-	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
-				i, #x,
-				model1_time,
-				math.sqrt(model1_mse / i), model1_psnr / i
+	       if baseline_output then
+		  io.stdout:write(
+		     string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
+				   i, #x,
+				   model1_time,
+				   math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
+				   baseline_psnr / i, model1_psnr / i
 		  ))
+	       else
+		  io.stdout:write(
+		     string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
+				   i, #x,
+				   model1_time,
+				   math.sqrt(model1_mse / i), model1_psnr / i
+		  ))
+	       end
 	    end
 	 end
 	 io.stdout:flush()
@@ -510,6 +588,14 @@ local function benchmark(opt, x, model1, model2)
 	 fp:write(string.format("model2  : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
 				math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
       end
+      if model1_am > 0 then
+	 fp:write(string.format("model1  : %s = %.3f, evaluation time = %.3f\n",
+				math.sqrt(model1_am / #x), model1_time))
+      end
+      if model2_am > 0 then
+	 fp:write(string.format("model2  : %s = %.3f, evaluation time = %.3f\n",
+				math.sqrt(model2_am / #x), model2_time))
+      end
       fp:close()
       if detail_fp then
 	 detail_fp:close()