Forráskód Böngészése

Add -tta and -resize_blur option to benchmark

nagadomi 9 éve
szülő
commit
5e222a3981
1 módosított fájl, 15 hozzáadás és 6 törlés
  1. 15 6
      tools/benchmark.lua

+ 15 - 6
tools/benchmark.lua

@@ -19,6 +19,7 @@ cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
 cmd:option("-model2_dir", "", 'model2 directory (optional)')
 cmd:option("-method", "scale", '(scale|noise)')
 cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))")
+cmd:option("-resize_blur", 1.0, 'blur parameter for resize')
 cmd:option("-color", "y", '(rgb|y)')
 cmd:option("-noise_level", 1, 'model noise level')
 cmd:option("-jpeg_quality", 75, 'jpeg quality')
@@ -34,6 +35,7 @@ cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Ca
 cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt')
 cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option')
 cmd:option("-thread", -1, 'number of CPU threads')
+cmd:option("-tta", 0, 'tta')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -50,6 +52,7 @@ if cudnn then
 end
 to_bool(opt, "gamma_correction")
 to_bool(opt, "save_all")
+to_bool(opt, "tta")
 if opt.save_all then
    opt.save_image = true
    opt.save_info = true
@@ -123,12 +126,12 @@ local function transform_scale(x, opt)
       return iproc.scale_with_gamma22(x,
 			 x:size(3) * 0.5,
 			 x:size(2) * 0.5,
-			 opt.filter)
+			 opt.filter, opt.resize_blur)
    else
       return iproc.scale(x,
 			 x:size(3) * 0.5,
 			 x:size(2) * 0.5,
-			 opt.filter)
+			 opt.filter, opt.resize_blur)
    end
 end
 
@@ -139,6 +142,12 @@ local function benchmark(opt, x, input_func, model1, model2)
    local model1_psnr = 0
    local model2_psnr = 0
    local baseline_psnr = 0
+   local scale_f = reconstruct.scale
+   local image_f = reconstruct.image
+   if opt.tta then
+      scale_f = reconstruct.scale_tta
+      image_f = reconstruct.image_tta
+   end
    
    for i = 1, #x do
       local ground_truth = x[i].image
@@ -149,14 +158,14 @@ local function benchmark(opt, x, input_func, model1, model2)
       input = input_func(ground_truth, opt)
       t = sys.clock()
       if input:size(3) == ground_truth:size(3) then
-	 model1_output = reconstruct.image(model1, input)
+	 model1_output = image_f(model1, input)
 	 if model2 then
-	    model2_output = reconstruct.image(model2, input)
+	    model2_output = image_f(model2, input)
 	 end
       else
-	 model1_output = reconstruct.scale(model1, 2.0, input)
+	 model1_output = scale_f(model1, 2.0, input)
 	 if model2 then
-	    model2_output = reconstruct.scale(model2, 2.0, input)
+	    model2_output = scale_f(model2, 2.0, input)
 	 end
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
       end