Browse Source

Add helper script for user method

nagadomi 8 years ago
parent
commit
f72d756172
1 changed files with 109 additions and 0 deletions
  1. 109 0
      tools/make_pairwise_list.lua

+ 109 - 0
tools/make_pairwise_list.lua

@@ -0,0 +1,109 @@
+require 'pl'
+local cjson = require 'cjson'
+
+local function pairwise_from_entries(y_dir, x_dirs)
+   local list = {}
+   local y_files = dir.getfiles(y_dir, "*")
+   for i, y_file in ipairs(y_files) do
+      local basename = path.basename(y_file)
+      local x_files = {}
+      for i = 1, #x_dirs do
+	 local x_file = path.join(x_dirs[i], basename)
+	 if path.exists(x_file) then
+	    table.insert(x_files, x_file)
+	 end
+      end
+      if #x_files == 1 then
+	 table.insert(list, {y =  y_file, x = x_files[1]})
+      elseif #x_files > 1 then
+	 local r = torch.random(1, #x_files)
+	 table.insert(list, {y =  y_file, x = x_files[r]})
+      end
+   end
+   return list
+end
+local function pairwise_from_list(y_dir, x_dirs, basename_file)
+   local list = {}
+   local basenames = utils.split(file.read(basename_file), "\n")
+   for i, basename in ipairs(basenames) do
+      local basename = path.basename(basename)
+      local y_file = path.join(y_dir, basename)
+      if path.exists(y_file) then
+	 local x_files = {}
+	 for i = 1, #x_dirs do
+	    local x_file = path.join(x_dirs[i], basename)
+	    if path.exists(x_file) then
+	       table.insert(x_files, x_file)
+	    end
+	 end
+	 if #x_files == 1 then
+	    table.insert(list, {y =  y_file, x = x_files[1]})
+	 elseif #x_files > 1 then
+	    local r = torch.random(1, #x_files)
+	    table.insert(list, {y =  y_file, x = x_files[r]})
+	 end
+      end
+   end
+   return list
+end
+local function output(list, filters, rate)
+   local n = math.floor(#list * rate)
+   if #list > 0 and n == 0 then
+      n = 1
+   end
+   local perm = torch.randperm(#list)
+   if #filters == 0 then
+      filters = nil
+   end
+   for i = 1, n do
+      local v = list[perm[i]]
+      io.stdout:write('"' .. v.y:gsub('"', '""') .. '"' .. "," .. '"' .. cjson.encode({x = v.x, filters = filters}):gsub('"', '""') .. '"' .. "\n")
+   end
+end
+local function get_xdirs(opt)
+   local x_dirs = {}
+   for k,v in pairs(opt) do
+      local s, e = k:find("x_dir")
+      if s == 1 then
+	 table.insert(x_dirs, v)
+      end
+   end
+   return x_dirs
+end
+
+local cmd = torch.CmdLine()
+cmd:text("waifu2x make_pairwise_list")
+cmd:option("-x_dir", "", 'Specify the directory for x(input)')
+cmd:option("-y_dir", "", 'Specify the directory for y(groundtruth). The filenames should be same as x_dir')
+cmd:option("-rate", 1, 'sampling rate')
+cmd:option("-file_list", "", 'Specify the basename list (optional)')
+cmd:option("-filters", "", 'Specify the downsampling filters')
+cmd:option("-x_dir1", "", 'x for random choice')
+cmd:option("-x_dir2", "", 'x for random choice')
+cmd:option("-x_dir3", "", 'x for random choice')
+cmd:option("-x_dir4", "", 'x for random choice')
+cmd:option("-x_dir5", "", 'x for random choice')
+cmd:option("-x_dir6", "", 'x for random choice')
+cmd:option("-x_dir7", "", 'x for random choice')
+cmd:option("-x_dir8", "", 'x for random choice')
+cmd:option("-x_dir9", "", 'x for random choice')
+cmd:option("-x_dir9", "", 'x for random choice')
+
+torch.manualSeed(71)
+
+local opt = cmd:parse(arg)
+
+local x_dirs = get_xdirs(opt)
+
+if opt.y_dir:len() == 0 or #x_dirs == 0 then
+   cmd:help()
+   os.exit(1)
+end
+
+local list
+if opt.file_list:len() > 0 then
+   list = pairwise_from_list(opt.y_dir, x_dirs, opt.file_list)
+else
+   list = pairwise_from_entries(opt.y_dir, x_dirs)
+end
+output(list, utils.split(opt.filters, ","), opt.rate)