Explorar o código

show baseline

nagadomi %!s(int64=9) %!d(string=hai) anos
pai
achega
c47df93505
Modificáronse 1 ficheiros con 41 adicións e 10 borrados
  1. 41 10
      tools/benchmark.lua

+ 41 - 10
tools/benchmark.lua

@@ -69,12 +69,14 @@ end
 local function benchmark(opt, x, input_func, model1, model2)
    local model1_mse = 0
    local model2_mse = 0
+   local baseline_mse = 0
    local model1_psnr = 0
    local model2_psnr = 0
+   local baseline_psnr = 0
    
    for i = 1, #x do
       local ground_truth = x[i]
-      local input, model1_output, model2_output
+      local input, model1_output, model2_output, baseline_output
 
       input = input_func(ground_truth, opt)
       input = input:float():div(255)
@@ -91,6 +93,7 @@ local function benchmark(opt, x, input_func, model1, model2)
 	 if model2 then
 	    model2_output = reconstruct.scale(model2, 2.0, input)
 	 end
+	 baseline_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, opt.filter)
       end
       if opt.color == "y" then
 	 model1_mse = model1_mse + YMSE(ground_truth, model1_output)
@@ -99,6 +102,10 @@ local function benchmark(opt, x, input_func, model1, model2)
 	    model2_mse = model2_mse + YMSE(ground_truth, model2_output)
 	    model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
 	 end
+	 if baseline_output then
+	    baseline_mse = baseline_mse + YMSE(ground_truth, baseline_output)
+	    baseline_psnr = baseline_psnr + YPSNR(ground_truth, baseline_output)
+	 end
       elseif opt.color == "rgb" then
 	 model1_mse = model1_mse + MSE(ground_truth, model1_output)
 	 model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
@@ -106,22 +113,46 @@ local function benchmark(opt, x, input_func, model1, model2)
 	    model2_mse = model2_mse + MSE(ground_truth, model2_output)
 	    model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
 	 end
+	 if baseline_output then
+	    baseline_mse = baseline_mse + MSE(ground_truth, baseline_output)
+	    baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output)
+	 end
       else
 	 error("Unknown color: " .. opt.color)
       end
       if model2 then
-	 io.stdout:write(
-	    string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
-			  i, #x,
-			  model1_mse / i, model2_mse / i,
-			  model1_psnr / i, model2_psnr / i
+	 if baseline_output then
+	    io.stdout:write(
+	       string.format("%d/%d; baseline_mse=%f, model1_mse=%f, model2_mse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
+			     i, #x,
+			     baseline_mse / i,
+			     model1_mse / i, model2_mse / i,
+			     baseline_psnr / i,
+			     model1_psnr / i, model2_psnr / i
+	    ))
+	 else
+	    io.stdout:write(
+	       string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
+			     i, #x,
+			     model1_mse / i, model2_mse / i,
+			     model1_psnr / i, model2_psnr / i
 	    ))
+	 end
       else
-	 io.stdout:write(
-	    string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
-			  i, #x,
-			  model1_mse / i, model1_psnr / i
+	 if baseline_output then
+	    io.stdout:write(
+	       string.format("%d/%d; baseline_mse=%f, model1_mse=%f, baseline_psnr=%f, model1_psnr=%f \r",
+			     i, #x,
+			     baseline_mse / i, model1_mse / i,
+			     baseline_psnr / i, model1_psnr / i
+	    ))
+	 else
+	    io.stdout:write(
+	       string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
+			     i, #x,
+			     model1_mse / i, model1_psnr / i
 	    ))
+	 end
       end
       io.stdout:flush()
    end