Browse Source

more detailed log in benchmark

nagadomi 8 years ago
parent
commit
1ce5a8d038
1 changed files with 18 additions and 0 deletions
  1. 18 0
      tools/benchmark.lua

+ 18 - 0
tools/benchmark.lua

@@ -193,6 +193,10 @@ local function benchmark(opt, x, model1, model2)
    local model2_time = 0
    local model2_time = 0
    local scale_f = reconstruct.scale
    local scale_f = reconstruct.scale
    local image_f = reconstruct.image
    local image_f = reconstruct.image
+   local detail_fp = nil
+   if opt.save_info then
+      detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
+   end
    if opt.tta then
    if opt.tta then
       scale_f = function(model, scale, x, block_size, batch_size)
       scale_f = function(model, scale, x, block_size, batch_size)
 	 return reconstruct.scale_tta(model, opt.tta_level,
 	 return reconstruct.scale_tta(model, opt.tta_level,
@@ -355,6 +359,8 @@ local function benchmark(opt, x, model1, model2)
       mse1 = MSE(ground_truth, model1_output, opt.color)
       mse1 = MSE(ground_truth, model1_output, opt.color)
       model1_mse = model1_mse + mse1
       model1_mse = model1_mse + mse1
       model1_psnr = model1_psnr + MSE2PSNR(mse1)
       model1_psnr = model1_psnr + MSE2PSNR(mse1)
+
+      local won_model = 1
       if model2 then
       if model2 then
 	 mse2 = MSE(ground_truth, model2_output, opt.color)
 	 mse2 = MSE(ground_truth, model2_output, opt.color)
 	 model2_mse = model2_mse + mse2
 	 model2_mse = model2_mse + mse2
@@ -364,6 +370,15 @@ local function benchmark(opt, x, model1, model2)
 	    won[1] = won[1] + 1
 	    won[1] = won[1] + 1
 	 elseif mse1 > mse2 then
 	 elseif mse1 > mse2 then
 	    won[2] = won[2] + 1
 	    won[2] = won[2] + 1
+	    won_model = 2
+	 end
+	 if detail_fp then
+	    detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
+					  MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
+	 end
+      else
+	 if detail_fp then
+	    detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
 	 end
 	 end
       end
       end
       if baseline_output then
       if baseline_output then
@@ -447,6 +462,9 @@ local function benchmark(opt, x, model1, model2)
 				math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
 				math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
       end
       end
       fp:close()
       fp:close()
+      if detail_fp then
+	 detail_fp:close()
+      end
    end
    end
    io.stdout:write("\n")
    io.stdout:write("\n")
 end
 end