瀏覽代碼

load float images directly

nagadomi 9 年之前
父節點
當前提交
e9187ecf1c
共有 1 個文件被更改,包括 8 次插入5 次删除
  1. 8 5
      tools/benchmark.lua

+ 8 - 5
tools/benchmark.lua

@@ -59,6 +59,12 @@ local function transform_jpeg(x, opt)
    end
    end
    return x
    return x
 end
 end
+local function baseline_scale(x, filter)
+   return iproc.scale(x,
+		      x:size(3) * 2.0,
+		      x:size(2) * 2.0,
+		      filter)
+end
 local function transform_scale(x, opt)
 local function transform_scale(x, opt)
    return iproc.scale(x,
    return iproc.scale(x,
 		      x:size(3) * 0.5,
 		      x:size(3) * 0.5,
@@ -79,9 +85,6 @@ local function benchmark(opt, x, input_func, model1, model2)
       local input, model1_output, model2_output, baseline_output
       local input, model1_output, model2_output, baseline_output
 
 
       input = input_func(ground_truth, opt)
       input = input_func(ground_truth, opt)
-      input = input:float():div(255)
-      ground_truth = ground_truth:float():div(255)
-      
       t = sys.clock()
       t = sys.clock()
       if input:size(3) == ground_truth:size(3) then
       if input:size(3) == ground_truth:size(3) then
 	 model1_output = reconstruct.image(model1, input)
 	 model1_output = reconstruct.image(model1, input)
@@ -93,7 +96,7 @@ local function benchmark(opt, x, input_func, model1, model2)
 	 if model2 then
 	 if model2 then
 	    model2_output = reconstruct.scale(model2, 2.0, input)
 	    model2_output = reconstruct.scale(model2, 2.0, input)
 	 end
 	 end
-	 baseline_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, opt.filter)
+	 baseline_output = baseline_scale(input, opt.filter)
       end
       end
       if opt.color == "y" then
       if opt.color == "y" then
 	 model1_mse = model1_mse + YMSE(ground_truth, model1_output)
 	 model1_mse = model1_mse + YMSE(ground_truth, model1_output)
@@ -162,7 +165,7 @@ local function load_data(test_dir)
    local test_x = {}
    local test_x = {}
    local files = dir.getfiles(test_dir, "*.*")
    local files = dir.getfiles(test_dir, "*.*")
    for i = 1, #files do
    for i = 1, #files do
-      table.insert(test_x, iproc.crop_mod4(image_loader.load_byte(files[i])))
+      table.insert(test_x, iproc.crop_mod4(image_loader.load_float(files[i])))
       xlua.progress(i, #files)
       xlua.progress(i, #files)
    end
    end
    return test_x
    return test_x