Selaa lähdekoodia

benchmark time

nagadomi 9 vuotta sitten
vanhempi
commit
67d36a1220
1 muutettua tiedostoa jossa 43 lisäystä ja 16 poistoa
  1. 43 16
      tools/benchmark.lua

+ 43 - 16
tools/benchmark.lua

@@ -105,6 +105,10 @@ local function PSNR(x1, x2, color)
    local mse = math.max(MSE(x1, x2, color), 1)
    return 10 * math.log10((255.0 * 255.0) / mse)
 end
+local function MSE2PSNR(mse)
+   return 10 * math.log10((255.0 * 255.0) / mse)
+end
+
 local function transform_jpeg(x, opt)
    for i = 1, opt.jpeg_times do
       jpeg = gm.Image(x, "RGB", "DHW")
@@ -146,12 +150,15 @@ local function transform_scale_jpeg(x, opt)
 end
 
 local function benchmark(opt, x, input_func, model1, model2)
+   local mse
    local model1_mse = 0
    local model2_mse = 0
    local baseline_mse = 0
    local model1_psnr = 0
    local model2_psnr = 0
    local baseline_psnr = 0
+   local model1_time = 0
+   local model2_time = 0
    local scale_f = reconstruct.scale
    local image_f = reconstruct.image
    if opt.tta then
@@ -168,24 +175,31 @@ local function benchmark(opt, x, input_func, model1, model2)
    for i = 1, #x do
       local ground_truth = x[i].image
       local basename = x[i].basename
-      
       local input, model1_output, model2_output, baseline_output
 
       input = input_func(ground_truth, opt)
-      t = sys.clock()
       if opt.method == "scale" then
+	 t = sys.clock()
 	 model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
+	 model1_time = model1_time + (sys.clock() - t)
 	 if model2 then
+	    t = sys.clock()
 	    model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
+	    model2_time = model2_time + (sys.clock() - t)
 	 end
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
       elseif opt.method == "noise" then
+	 t = sys.clock()
 	 model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
+	 model1_time = model1_time + (sys.clock() - t)
 	 if model2 then
+	    t = sys.clock()
 	    model2_output = image_f(model2, input, opt.crop_size, opt.batch_size)
+	    model2_time = model2_time + (sys.clock() - t)
 	 end
 	 baseline_output = input
       elseif opt.method == "noise_scale" then
+	 t = sys.clock()
 	 if model1.noise_scale_model then
 	    model1_output = scale_f(model1.noise_scale_model, 2.0,
 				    input, opt.crop_size, opt.batch_size)
@@ -198,7 +212,10 @@ local function benchmark(opt, x, input_func, model1, model2)
 	    model1_output = scale_f(model1.scale_model, 2.0, model1_output,
 				    opt.crop_size, opt.batch_size)
 	 end
+	 model1_time = model1_time + (sys.clock() - t)
+
 	 if model2 then
+	    t = sys.clock()
 	    if model2.noise_scale_model then
 	       model2_output = scale_f(model2.noise_scale_model, 2.0,
 				       input, opt.crop_size, opt.batch_size)
@@ -212,18 +229,22 @@ local function benchmark(opt, x, input_func, model1, model2)
 	       model2_output = scale_f(model2.scale_model, 2.0, model2_output,
 				       opt.crop_size, opt.batch_size)
 	    end
+	    model2_time = model2_time + (sys.clock() - t)
 	 end
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
       end
-      model1_mse = model1_mse + MSE(ground_truth, model1_output, opt.color)
-      model1_psnr = model1_psnr + PSNR(ground_truth, model1_output, opt.color)
+      mse = MSE(ground_truth, model1_output, opt.color)
+      model1_mse = model1_mse + mse
+      model1_psnr = model1_psnr + MSE2PSNR(mse)
       if model2 then
-	 model2_mse = model2_mse + MSE(ground_truth, model2_output, opt.color)
-	 model2_psnr = model2_psnr + PSNR(ground_truth, model2_output, opt.color)
+	 mse = MSE(ground_truth, model2_output, opt.color)
+	 model2_mse = model2_mse + mse
+	 model2_psnr = model2_psnr + MSE2PSNR(mse)
       end
       if baseline_output then
-	 baseline_mse = baseline_mse + MSE(ground_truth, baseline_output, opt.color)
-	 baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output, opt.color)
+	 mse = MSE(ground_truth, baseline_output, opt.color)
+	 baseline_mse = baseline_mse + mse
+	 baseline_psnr = baseline_psnr + MSE2PSNR(mse)
       end
       if opt.save_image then
 	 if opt.save_baseline_image and baseline_output then
@@ -243,8 +264,10 @@ local function benchmark(opt, x, input_func, model1, model2)
 	 if model2 then
 	    if baseline_output then
 	       io.stdout:write(
-		  string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
 				i, #x,
+				model1_time,
+				model2_time,
 				math.sqrt(baseline_mse / i),
 				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
 				baseline_psnr / i,
@@ -252,8 +275,10 @@ local function benchmark(opt, x, input_func, model1, model2)
 		  ))
 	    else
 	       io.stdout:write(
-		  string.format("%d/%d; model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
 				i, #x,
+				model1_time,
+				model2_time,
 				math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
 				model1_psnr / i, model2_psnr / i
 		  ))
@@ -261,15 +286,17 @@ local function benchmark(opt, x, input_func, model1, model2)
 	 else
 	    if baseline_output then
 	       io.stdout:write(
-		  string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r",
 				i, #x,
+				model1_time,
 				math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
 				baseline_psnr / i, model1_psnr / i
 		  ))
 	    else
 	       io.stdout:write(
-		  string.format("%d/%d; model1_rmse=%f, model1_psnr=%f \r",
+		  string.format("%d/%d; model1_time=%.2f, model1_rmse=%f, model1_psnr=%f \r",
 				i, #x,
+				model1_time,
 				math.sqrt(model1_mse / i), model1_psnr / i
 		  ))
 	    end
@@ -285,12 +312,12 @@ local function benchmark(opt, x, input_func, model1, model2)
 				math.sqrt(baseline_mse / #x), baseline_psnr / #x))
       end
       if model1_psnr > 0 then
-	 fp:write(string.format("model1  : RMSE = %.3f, PSNR = %.3f\n",
-				math.sqrt(model1_mse / #x), model1_psnr / #x))
+	 fp:write(string.format("model1  : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
+				math.sqrt(model1_mse / #x), model1_psnr / #x, model1_time))
       end
       if model2_psnr > 0 then
-	 fp:write(string.format("model2  : RMSE = %.3f, PSNR = %.3f\n",
-				math.sqrt(model2_mse / #x), model2_psnr / #x))
+	 fp:write(string.format("model2  : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
+				math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
       end
       fp:close()
    end