소스 검색

Reduce memory usage in benchmark

nagadomi 8 년 전
부모
커밋
02cf265d48
1개의 변경된 파일28개의 추가작업 그리고 16개의 파일을 삭제
  1. 28 16
      tools/benchmark.lua

+ 28 - 16
tools/benchmark.lua

@@ -227,12 +227,15 @@ local function benchmark(opt, x, model1, model2)
    end
 
    for i = 1, #x do
+      if i % 10 == 0 then
+	 collectgarbage()
+      end
       local basename = x[i].basename
       local input, model1_output, model2_output, baseline_output, ground_truth
 
       if opt.method == "scale" then
-	 input = transform_scale(x[i].y, opt)
-	 ground_truth = x[i].y
+	 input = transform_scale(iproc.byte2float(x[i].y), opt)
+	 ground_truth = iproc.byte2float(x[i].y)
 
 	 if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
 	    model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
@@ -250,8 +253,8 @@ local function benchmark(opt, x, model1, model2)
 	 end
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
       elseif opt.method == "scale4" then
-	 input = transform_scale4(x[i].y, opt)
-	 ground_truth = x[i].y
+	 input = transform_scale4(iproc.byte2float(x[i].y), opt)
+	 ground_truth = iproc.byte2float(x[i].y)
 	 if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
 	    model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
 	    if model2 then
@@ -270,8 +273,8 @@ local function benchmark(opt, x, model1, model2)
 	 end
 	 baseline_output = baseline_scale4(input, opt.baseline_filter)
       elseif opt.method == "noise" then
-	 input = transform_jpeg(x[i].y, opt)
-	 ground_truth = x[i].y
+	 input = transform_jpeg(iproc.byte2float(x[i].y), opt)
+	 ground_truth = iproc.byte2float(x[i].y)
 
 	 if opt.force_cudnn and i == 1 then
 	    model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
@@ -289,8 +292,8 @@ local function benchmark(opt, x, model1, model2)
 	 end
 	 baseline_output = input
       elseif opt.method == "noise_scale" then
-	 input = transform_scale_jpeg(x[i].y, opt)
-	 ground_truth = x[i].y
+	 input = transform_scale_jpeg(iproc.byte2float(x[i].y), opt)
+	 ground_truth = iproc.byte2float(x[i].y)
 
 	 if opt.force_cudnn and i == 1 then
 	    if model1.noise_scale_model then
@@ -355,8 +358,8 @@ local function benchmark(opt, x, model1, model2)
 	 end
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
       elseif opt.method == "user" then
-	 input = x[i].x
-	 ground_truth = x[i].y
+	 input = iproc.byte2float(x[i].x)
+	 ground_truth = iproc.byte2float(x[i].y)
 	 local y_scale = ground_truth:size(2) / input:size(2)
 	 if y_scale > 1 then
 	    if opt.force_cudnn and i == 1 then
@@ -390,8 +393,8 @@ local function benchmark(opt, x, model1, model2)
 	    end
 	 end
       elseif opt.method == "diff" then
-	 input = x[i].x
-	 ground_truth = x[i].y
+	 input = iproc.byte2float(x[i].x)
+	 ground_truth = iproc.byte2float(x[i].y)
 	 model1_output = input
       end
       if opt.border > 0 then
@@ -521,7 +524,7 @@ local function load_data_from_dir(test_dir)
       local name = path.basename(files[i])
       local e = path.extension(name)
       local base = name:sub(0, name:len() - e:len())
-      local img = image_loader.load_float(files[i])
+      local img = image_loader.load_byte(files[i])
       if img then
 	 table.insert(test_x, {y = iproc.crop_mod4(img),
 			       basename = base})
@@ -529,6 +532,9 @@ local function load_data_from_dir(test_dir)
       if opt.show_progress then
 	 xlua.progress(i, #files)
       end
+      if i % 10 == 0 then
+	 collectgarbage()
+      end
    end
    return test_x
 end
@@ -539,7 +545,7 @@ local function load_data_from_file(test_file)
       local name = path.basename(files[i])
       local e = path.extension(name)
       local base = name:sub(0, name:len() - e:len())
-      local img = image_loader.load_float(files[i])
+      local img = image_loader.load_byte(files[i])
       if img then
 	 table.insert(test_x, {y = iproc.crop_mod4(img),
 			       basename = base})
@@ -547,6 +553,9 @@ local function load_data_from_file(test_file)
       if opt.show_progress then
 	 xlua.progress(i, #files)
       end
+      if i % 10 == 0 then
+	 collectgarbage()
+      end
    end
    return test_x
 end
@@ -592,8 +601,8 @@ local function load_user_data(y_dir, y_file, x_dir, x_file)
    end
    for i = 1, #y_files do
       local key = get_basename(y_files[i])
-      local x = image_loader.load_float(basename_db[key].x)
-      local y = image_loader.load_float(basename_db[key].y)
+      local x = image_loader.load_byte(basename_db[key].x)
+      local y = image_loader.load_byte(basename_db[key].y)
       if x and y then
 	 table.insert(test, {y = y,
 			     x = x,
@@ -602,6 +611,9 @@ local function load_user_data(y_dir, y_file, x_dir, x_file)
       if opt.show_progress then
 	 xlua.progress(i, #y_files)
       end
+      if i % 10 == 0 then
+	 collectgarbage()
+      end
    end
    return test
 end