Browse Source

Add support for the test filelist in benchmark

nagadomi 8 years ago
parent
commit
bfb67e61f4
1 changed files with 57 additions and 8 deletions
  1. 57 8
      tools/benchmark.lua

+ 57 - 8
tools/benchmark.lua

@@ -15,6 +15,7 @@ cmd:text("waifu2x-benchmark")
 cmd:text("Options:")
 
 cmd:option("-dir", "./data/test", 'test image directory')
+cmd:option("-file", "", 'test image file list')
 cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
 cmd:option("-model2_dir", "", 'model2 directory (optional)')
 cmd:option("-method", "scale", '(scale|noise|noise_scale|user)')
@@ -43,6 +44,8 @@ cmd:option("-yuv420", 0, 'use yuv420 jpeg')
 cmd:option("-name", "", 'model name for user method')
 cmd:option("-x_dir", "", 'input image for user method')
 cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be the same as x_dir')
+cmd:option("-x_file", "", 'input image for user method')
+cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -431,7 +434,7 @@ local function benchmark(opt, x, model1, model2)
    end
    io.stdout:write("\n")
 end
-local function load_data(test_dir)
+local function load_data_from_dir(test_dir)
    local test_x = {}
    local files = dir.getfiles(test_dir, "*.*")
    for i = 1, #files do
@@ -449,16 +452,47 @@ local function load_data(test_dir)
    end
    return test_x
 end
+local function load_data_from_file(test_file)
+   local test_x = {}
+   local files = utils.split(file.read(test_file), "\n")
+   for i = 1, #files do
+      local name = path.basename(files[i])
+      local e = path.extension(name)
+      local base = name:sub(0, name:len() - e:len())
+      local img = image_loader.load_float(files[i])
+      if img then
+	 table.insert(test_x, {y = iproc.crop_mod4(img),
+			       basename = base})
+      end
+      if opt.show_progress then
+	 xlua.progress(i, #files)
+      end
+   end
+   return test_x
+end
 local function get_basename(f)
    local name = path.basename(f)
    local e = path.extension(name)
    local base = name:sub(0, name:len() - e:len())
    return base
 end
-local function load_user_data(y_dir, x_dir)
+local function load_user_data(y_dir, y_file, x_dir, x_file)
    local test = {}
-   local y_files = dir.getfiles(y_dir, "*.*")
-   local x_files = dir.getfiles(x_dir, "*.*")
+   local y_files
+   local x_files
+
+   if y_file:len() > 0 then
+      print(y_file)
+
+      y_files = utils.split(file.read(y_file), "\n")
+   else
+      y_files = dir.getfiles(y_dir, "*.*")
+   end
+   if x_file:len() > 0 then
+      x_files = utils.split(file.read(x_file), "\n")
+   else
+      x_files = dir.getfiles(x_dir, "*.*")
+   end
    local basename_db = {}
    for i = 1, #y_files do
       basename_db[get_basename(y_files[i])] = {y = y_files[i]}
@@ -535,7 +569,12 @@ if opt.method == "scale" then
    if not s2 then
       model2 = nil
    end
-   local test_x = load_data(opt.dir)
+   local test_x
+   if opt.file:len() > 0 then
+      test_x = load_data_from_file(opt.file)
+   else
+      test_x = load_data_from_dir(opt.dir)
+   end
    benchmark(opt, test_x, model1, model2)
 elseif opt.method == "noise" then
    local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
@@ -548,7 +587,12 @@ elseif opt.method == "noise" then
    if not s2 then
       model2 = nil
    end
-   local test_x = load_data(opt.dir)
+   local test_x
+   if opt.file:len() > 0 then
+      test_x = load_data_from_file(opt.file)
+   else
+      test_x = load_data_from_dir(opt.dir)
+   end
    benchmark(opt, test_x, model1, model2)
 elseif opt.method == "noise_scale" then
    local model2 = nil
@@ -556,7 +600,12 @@ elseif opt.method == "noise_scale" then
    if opt.model2_dir:len() > 0 then
       model2 = load_noise_scale_model(opt.model2_dir, opt.noise_level, opt.force_cudnn)
    end
-   local test_x = load_data(opt.dir)
+   local test_x
+   if opt.file:len() > 0 then
+      test_x = load_data_from_file(opt.file)
+   else
+      test_x = load_data_from_dir(opt.dir)
+   end
    benchmark(opt, test_x, model1, model2)
 elseif opt.method == "user" then
    local f1 = path.join(opt.model1_dir, string.format("%s_model.t7", opt.name))
@@ -569,6 +618,6 @@ elseif opt.method == "user" then
    if not s2 then
       model2 = nil
    end
-   local test = load_user_data(opt.y_dir, opt.x_dir)
+   local test = load_user_data(opt.y_dir, opt.y_file, opt.x_dir, opt.x_file)
    benchmark(opt, test, model1, model2)
 end