瀏覽代碼

Add count of won in benchmark

nagadomi 8 年之前
父節點
當前提交
51e189b6c0
共有 1 個文件被更改,包括 20 次插入11 次删除
  1. 20 11
      tools/benchmark.lua

+ 20 - 11
tools/benchmark.lua

@@ -181,7 +181,8 @@ local function transform_scale_jpeg(x, opt)
 end
 
 local function benchmark(opt, x, model1, model2)
-   local mse
+   local mse1, mse2
+   local won = {0, 0}
    local model1_mse = 0
    local model2_mse = 0
    local baseline_mse = 0
@@ -351,13 +352,19 @@ local function benchmark(opt, x, model1, model2)
 	 ground_truth = x[i].y
 	 model1_output = input
       end
-      mse = MSE(ground_truth, model1_output, opt.color)
-      model1_mse = model1_mse + mse
-      model1_psnr = model1_psnr + MSE2PSNR(mse)
+      mse1 = MSE(ground_truth, model1_output, opt.color)
+      model1_mse = model1_mse + mse1
+      model1_psnr = model1_psnr + MSE2PSNR(mse1)
       if model2 then
-	 mse = MSE(ground_truth, model2_output, opt.color)
-	 model2_mse = model2_mse + mse
-	 model2_psnr = model2_psnr + MSE2PSNR(mse)
+	 mse2 = MSE(ground_truth, model2_output, opt.color)
+	 model2_mse = model2_mse + mse2
+	 model2_psnr = model2_psnr + MSE2PSNR(mse2)
+
+	 if mse1 < mse2 then
+	    won[1] = won[1] + 1
+	 elseif mse1 > mse2 then
+	    won[2] = won[2] + 1
+	 end
       end
       if baseline_output then
 	 mse = MSE(ground_truth, baseline_output, opt.color)
@@ -382,23 +389,25 @@ local function benchmark(opt, x, model1, model2)
 	 if model2 then
 	    if baseline_output then
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f, model1_won=%d, model2_won=%d \r",
 				i, #x,
 				model1_time,
 				model2_time,
 				math.sqrt(baseline_mse / i),
 				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
 				baseline_psnr / i,
-				model1_psnr / i, model2_psnr / i
+				model1_psnr / i, model2_psnr / i,
+				won[1], won[2]
 		  ))
 	    else
 	       io.stdout:write(
-		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f, model1_own=%d, model2_won=%d \r",
 				i, #x,
 				model1_time,
 				model2_time,
 				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
-				model1_psnr / i, model2_psnr / i
+				model1_psnr / i, model2_psnr / i,
+				won[1], won[2]
 		  ))
 	    end
 	 else