浏览代码

Add -crop_size and -batch_size option to tools/benchmark.lua. Fix a bug in tta mode.

nagadomi 9 年之前
父节点
当前提交
c16d0a07a2
共有 1 个文件被更改,包括 16 次插入7 次删除
  1. 16 7
      tools/benchmark.lua

+ 16 - 7
tools/benchmark.lua

@@ -34,7 +34,10 @@ cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Ca
 cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt')
 cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt')
 cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option')
 cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option')
 cmd:option("-thread", -1, 'number of CPU threads')
 cmd:option("-thread", -1, 'number of CPU threads')
-cmd:option("-tta", 0, 'tta')
+cmd:option("-tta", 0, 'use tta')
+cmd:option("-tta_level", 8, 'tta level')
+cmd:option("-crop_size", 128, 'patch size per process')
+cmd:option("-batch_size", 1, 'batch_size')
 
 
 local function to_bool(settings, name)
 local function to_bool(settings, name)
    if settings[name] == 1 then
    if settings[name] == 1 then
@@ -136,8 +139,14 @@ local function benchmark(opt, x, input_func, model1, model2)
    local scale_f = reconstruct.scale
    local scale_f = reconstruct.scale
    local image_f = reconstruct.image
    local image_f = reconstruct.image
    if opt.tta then
    if opt.tta then
-      scale_f = reconstruct.scale_tta
-      image_f = reconstruct.image_tta
+      scale_f = function(model, scale, x, block_size, batch_size)
+	 return reconstruct.scale_tta(model, opt.tta_level,
+				      scale, x, block_size, batch_size)
+      end
+      image_f = function(model, x, block_size, batch_size)
+	 return reconstruct.image_tta(model, opt.tta_level,
+				      x, block_size, batch_size)
+      end
    end
    end
    
    
    for i = 1, #x do
    for i = 1, #x do
@@ -149,14 +158,14 @@ local function benchmark(opt, x, input_func, model1, model2)
       input = input_func(ground_truth, opt)
       input = input_func(ground_truth, opt)
       t = sys.clock()
       t = sys.clock()
       if input:size(3) == ground_truth:size(3) then
       if input:size(3) == ground_truth:size(3) then
-	 model1_output = image_f(model1, input)
+	 model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
 	 if model2 then
 	 if model2 then
-	    model2_output = image_f(model2, input)
+	    model2_output = image_f(model2, input, opt.crop_size, opt.batch_size)
 	 end
 	 end
       else
       else
-	 model1_output = scale_f(model1, 2.0, input)
+	 model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
 	 if model2 then
 	 if model2 then
-	    model2_output = scale_f(model2, 2.0, input)
+	    model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
 	 end
 	 end
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
       end
       end