|
@@ -34,7 +34,10 @@ cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Ca
|
|
|
cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt')
|
|
|
cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option')
|
|
|
cmd:option("-thread", -1, 'number of CPU threads')
|
|
|
-cmd:option("-tta", 0, 'tta')
|
|
|
+cmd:option("-tta", 0, 'use tta')
|
|
|
+cmd:option("-tta_level", 8, 'tta level')
|
|
|
+cmd:option("-crop_size", 128, 'patch size per process')
|
|
|
+cmd:option("-batch_size", 1, 'batch_size')
|
|
|
|
|
|
local function to_bool(settings, name)
|
|
|
if settings[name] == 1 then
|
|
@@ -136,8 +139,14 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|
|
local scale_f = reconstruct.scale
|
|
|
local image_f = reconstruct.image
|
|
|
if opt.tta then
|
|
|
- scale_f = reconstruct.scale_tta
|
|
|
- image_f = reconstruct.image_tta
|
|
|
+ scale_f = function(model, scale, x, block_size, batch_size)
|
|
|
+ return reconstruct.scale_tta(model, opt.tta_level,
|
|
|
+ scale, x, block_size, batch_size)
|
|
|
+ end
|
|
|
+ image_f = function(model, x, block_size, batch_size)
|
|
|
+ return reconstruct.image_tta(model, opt.tta_level,
|
|
|
+ x, block_size, batch_size)
|
|
|
+ end
|
|
|
end
|
|
|
|
|
|
for i = 1, #x do
|
|
@@ -149,14 +158,14 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|
|
input = input_func(ground_truth, opt)
|
|
|
t = sys.clock()
|
|
|
if input:size(3) == ground_truth:size(3) then
|
|
|
- model1_output = image_f(model1, input)
|
|
|
+ model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
|
|
|
if model2 then
|
|
|
- model2_output = image_f(model2, input)
|
|
|
+ model2_output = image_f(model2, input, opt.crop_size, opt.batch_size)
|
|
|
end
|
|
|
else
|
|
|
- model1_output = scale_f(model1, 2.0, input)
|
|
|
+ model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
|
|
|
if model2 then
|
|
|
- model2_output = scale_f(model2, 2.0, input)
|
|
|
+ model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
|
|
|
end
|
|
|
baseline_output = baseline_scale(input, opt.baseline_filter)
|
|
|
end
|