|
@@ -15,6 +15,7 @@ cmd:text("waifu2x-benchmark")
|
|
|
cmd:text("Options:")
|
|
|
|
|
|
cmd:option("-dir", "./data/test", 'test image directory')
|
|
|
+cmd:option("-file", "", 'test image file list')
|
|
|
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
|
|
|
cmd:option("-model2_dir", "", 'model2 directory (optional)')
|
|
|
cmd:option("-method", "scale", '(scale|noise|noise_scale|user)')
|
|
@@ -43,6 +44,8 @@ cmd:option("-yuv420", 0, 'use yuv420 jpeg')
|
|
|
cmd:option("-name", "", 'model name for user method')
|
|
|
cmd:option("-x_dir", "", 'input image for user method')
|
|
|
cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be the same as x_dir')
|
|
|
+cmd:option("-x_file", "", 'input image for user method')
|
|
|
+cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
|
|
|
|
|
|
local function to_bool(settings, name)
|
|
|
if settings[name] == 1 then
|
|
@@ -431,7 +434,7 @@ local function benchmark(opt, x, model1, model2)
|
|
|
end
|
|
|
io.stdout:write("\n")
|
|
|
end
|
|
|
-local function load_data(test_dir)
|
|
|
+local function load_data_from_dir(test_dir)
|
|
|
local test_x = {}
|
|
|
local files = dir.getfiles(test_dir, "*.*")
|
|
|
for i = 1, #files do
|
|
@@ -449,16 +452,47 @@ local function load_data(test_dir)
|
|
|
end
|
|
|
return test_x
|
|
|
end
|
|
|
+local function load_data_from_file(test_file)
|
|
|
+ local test_x = {}
|
|
|
+ local files = utils.split(file.read(test_file), "\n")
|
|
|
+ for i = 1, #files do
|
|
|
+ local name = path.basename(files[i])
|
|
|
+ local e = path.extension(name)
|
|
|
+ local base = name:sub(0, name:len() - e:len())
|
|
|
+ local img = image_loader.load_float(files[i])
|
|
|
+ if img then
|
|
|
+ table.insert(test_x, {y = iproc.crop_mod4(img),
|
|
|
+ basename = base})
|
|
|
+ end
|
|
|
+ if opt.show_progress then
|
|
|
+ xlua.progress(i, #files)
|
|
|
+ end
|
|
|
+ end
|
|
|
+ return test_x
|
|
|
+end
|
|
|
local function get_basename(f)
|
|
|
local name = path.basename(f)
|
|
|
local e = path.extension(name)
|
|
|
local base = name:sub(0, name:len() - e:len())
|
|
|
return base
|
|
|
end
|
|
|
-local function load_user_data(y_dir, x_dir)
|
|
|
+local function load_user_data(y_dir, y_file, x_dir, x_file)
|
|
|
local test = {}
|
|
|
- local y_files = dir.getfiles(y_dir, "*.*")
|
|
|
- local x_files = dir.getfiles(x_dir, "*.*")
|
|
|
+ local y_files
|
|
|
+ local x_files
|
|
|
+
|
|
|
+ if y_file:len() > 0 then
|
|
|
+ print(y_file)
|
|
|
+
|
|
|
+ y_files = utils.split(file.read(y_file), "\n")
|
|
|
+ else
|
|
|
+ y_files = dir.getfiles(y_dir, "*.*")
|
|
|
+ end
|
|
|
+ if x_file:len() > 0 then
|
|
|
+ x_files = utils.split(file.read(x_file), "\n")
|
|
|
+ else
|
|
|
+ x_files = dir.getfiles(x_dir, "*.*")
|
|
|
+ end
|
|
|
local basename_db = {}
|
|
|
for i = 1, #y_files do
|
|
|
basename_db[get_basename(y_files[i])] = {y = y_files[i]}
|
|
@@ -535,7 +569,12 @@ if opt.method == "scale" then
|
|
|
if not s2 then
|
|
|
model2 = nil
|
|
|
end
|
|
|
- local test_x = load_data(opt.dir)
|
|
|
+ local test_x
|
|
|
+ if opt.file:len() > 0 then
|
|
|
+ test_x = load_data_from_file(opt.file)
|
|
|
+ else
|
|
|
+ test_x = load_data_from_dir(opt.dir)
|
|
|
+ end
|
|
|
benchmark(opt, test_x, model1, model2)
|
|
|
elseif opt.method == "noise" then
|
|
|
local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
|
|
@@ -548,7 +587,12 @@ elseif opt.method == "noise" then
|
|
|
if not s2 then
|
|
|
model2 = nil
|
|
|
end
|
|
|
- local test_x = load_data(opt.dir)
|
|
|
+ local test_x
|
|
|
+ if opt.file:len() > 0 then
|
|
|
+ test_x = load_data_from_file(opt.file)
|
|
|
+ else
|
|
|
+ test_x = load_data_from_dir(opt.dir)
|
|
|
+ end
|
|
|
benchmark(opt, test_x, model1, model2)
|
|
|
elseif opt.method == "noise_scale" then
|
|
|
local model2 = nil
|
|
@@ -556,7 +600,12 @@ elseif opt.method == "noise_scale" then
|
|
|
if opt.model2_dir:len() > 0 then
|
|
|
model2 = load_noise_scale_model(opt.model2_dir, opt.noise_level, opt.force_cudnn)
|
|
|
end
|
|
|
- local test_x = load_data(opt.dir)
|
|
|
+ local test_x
|
|
|
+ if opt.file:len() > 0 then
|
|
|
+ test_x = load_data_from_file(opt.file)
|
|
|
+ else
|
|
|
+ test_x = load_data_from_dir(opt.dir)
|
|
|
+ end
|
|
|
benchmark(opt, test_x, model1, model2)
|
|
|
elseif opt.method == "user" then
|
|
|
local f1 = path.join(opt.model1_dir, string.format("%s_model.t7", opt.name))
|
|
@@ -569,6 +618,6 @@ elseif opt.method == "user" then
|
|
|
if not s2 then
|
|
|
model2 = nil
|
|
|
end
|
|
|
- local test = load_user_data(opt.y_dir, opt.x_dir)
|
|
|
+ local test = load_user_data(opt.y_dir, opt.y_file, opt.x_dir, opt.x_file)
|
|
|
benchmark(opt, test, model1, model2)
|
|
|
end
|