瀏覽代碼

Don't run model2 benchmark when model2_dir is not specified

nagadomi 9 年之前
父節點
當前提交
7ac7923345
共有 1 個文件被更改,包括 53 次插入21 次删除
  1. 53 21
      tools/benchmark.lua

+ 53 - 21
tools/benchmark.lua

@@ -14,8 +14,8 @@ cmd:text("waifu2x-benchmark")
 cmd:text("Options:")
 
 cmd:option("-dir", "./data/test", 'test image directory')
-cmd:option("-model1_dir", "./models/anime_style_art", 'model1 directory')
-cmd:option("-model2_dir", "./models/anime_style_art_rgb", 'model2 directory')
+cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
+cmd:option("-model2_dir", "", 'model2 directory (optional)')
 cmd:option("-method", "scale", '(scale|noise)')
 cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)")
 cmd:option("-color", "rgb", '(rgb|y)')
@@ -83,32 +83,46 @@ local function benchmark(opt, x, input_func, model1, model2)
       t = sys.clock()
       if input:size(3) == ground_truth:size(3) then
 	 model1_output = reconstruct.image(model1, input)
-	 model2_output = reconstruct.image(model2, input)
+	 if model2 then
+	    model2_output = reconstruct.image(model2, input)
+	 end
       else
 	 model1_output = reconstruct.scale(model1, 2.0, input)
-	 model2_output = reconstruct.scale(model2, 2.0, input)
+	 if model2 then
+	    model2_output = reconstruct.scale(model2, 2.0, input)
+	 end
       end
       if opt.color == "y" then
 	 model1_mse = model1_mse + YMSE(ground_truth, model1_output)
 	 model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
-	 model2_mse = model2_mse + YMSE(ground_truth, model2_output)
-	 model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
+	 if model2 then
+	    model2_mse = model2_mse + YMSE(ground_truth, model2_output)
+	    model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
+	 end
       elseif opt.color == "rgb" then
 	 model1_mse = model1_mse + MSE(ground_truth, model1_output)
 	 model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
-	 model2_mse = model2_mse + MSE(ground_truth, model2_output)
-	 model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
+	 if model2 then
+	    model2_mse = model2_mse + MSE(ground_truth, model2_output)
+	    model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
+	 end
       else
 	 error("Unknown color: " .. opt.color)
       end
-      
-      io.stdout:write(
-	 string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
-		       i, #x,
-		       model1_mse / i, model2_mse / i,
-		       model1_psnr / i, model2_psnr / i
-	 )
-      )
+      if model2 then
+	 io.stdout:write(
+	    string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
+			  i, #x,
+			  model1_mse / i, model2_mse / i,
+			  model1_psnr / i, model2_psnr / i
+	    ))
+      else
+	 io.stdout:write(
+	    string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
+			  i, #x,
+			  model1_mse / i, model1_psnr / i
+	    ))
+      end
       io.stdout:flush()
    end
    io.stdout:write("\n")
@@ -122,16 +136,34 @@ local function load_data(test_dir)
    end
    return test_x
 end
-
+function load_model(filename)
+   return torch.load(filename, "ascii")
+end
 print(opt)
 if opt.method == "scale" then
-   local model1 = torch.load(path.join(opt.model1_dir, "scale2.0x_model.t7"), "ascii")
-   local model2 = torch.load(path.join(opt.model2_dir, "scale2.0x_model.t7"), "ascii")
+   local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
+   local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
+   local s1, model1 = pcall(load_model, f1)
+   local s2, model2 = pcall(load_model, f2)
+   if not s1 then
+      error("Load error: " .. f1)
+   end
+   if not s2 then
+      model2 = nil
+   end
    local test_x = load_data(opt.dir)
    benchmark(opt, test_x, transform_scale, model1, model2)
 elseif opt.method == "noise" then
-   local model1 = torch.load(path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
-   local model2 = torch.load(path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
+   local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
+   local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level))
+   local s1, model1 = pcall(load_model, f1)
+   local s2, model2 = pcall(load_model, f2)
+   if not s1 then
+      error("Load error: " .. f1)
+   end
+   if not s2 then
+      model2 = nil
+   end
    local test_x = load_data(opt.dir)
    benchmark(opt, test_x, transform_jpeg, model1, model2)
 end