make_pairwise_list.lua 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. require 'pl'
  2. local cjson = require 'cjson'
  -- Build (y, x) pairs by scanning every file in y_dir.
  --
  -- For each y file, every directory in x_dirs is searched for a regular file
  -- with the same basename. A y file with no match is skipped. A single match
  -- is used directly. With several matches, one is drawn via torch.random.
  --
  -- @param y_dir  directory holding the y (ground-truth) files
  -- @param x_dirs array of directories holding candidate x (input) files
  -- @return array of {y = <y path>, x = <x path>}
  local function pairwise_from_entries(y_dir, x_dirs)
     local list = {}
     local y_files = dir.getfiles(y_dir, "*")
     -- Directory listing order is filesystem dependent. Sort it so that the
     -- sequence of torch.random draws below does not depend on readdir order.
     table.sort(y_files)
     for _, y_file in ipairs(y_files) do
        local basename = path.basename(y_file)
        local candidates = {}
        for _, x_dir in ipairs(x_dirs) do
           local x_file = path.join(x_dir, basename)
           -- isfile rather than exists: a same-named subdirectory is not a match.
           if path.isfile(x_file) then
              candidates[#candidates + 1] = x_file
           end
        end
        if #candidates == 1 then
           list[#list + 1] = {y = y_file, x = candidates[1]}
        elseif #candidates > 1 then
           -- The RNG is only consulted when there is an actual choice to make.
           list[#list + 1] = {y = y_file, x = candidates[torch.random(1, #candidates)]}
        end
     end
     return list
  end
  -- Build (y, x) pairs for the names listed in basename_file, one per line.
  --
  -- Only the basename of each listed entry is used. A pair is produced when
  -- y_dir contains a regular file with that name and at least one directory in
  -- x_dirs does too. With several x matches, one is drawn via torch.random.
  --
  -- @param y_dir         directory holding the y (ground-truth) files
  -- @param x_dirs        array of directories holding candidate x (input) files
  -- @param basename_file path of a text file with one name per line
  -- @return array of {y = <y path>, x = <x path>}
  local function pairwise_from_list(y_dir, x_dirs, basename_file)
     local list = {}
     -- io.lines raises a descriptive error if the list file cannot be opened,
     -- instead of handing nil to a string splitter.
     for line in io.lines(basename_file) do
        -- Tolerate CRLF list files. Without this, every name keeps a trailing
        -- "\r" and never matches a file on disk.
        local name = line:gsub("\r$", "")
        -- Skip blank lines: an empty basename joins to the directory path
        -- itself, which would otherwise pass an existence check.
        if name:find("%S") then
           local basename = path.basename(name)
           local y_file = path.join(y_dir, basename)
           if path.isfile(y_file) then
              local candidates = {}
              for _, x_dir in ipairs(x_dirs) do
                 local x_file = path.join(x_dir, basename)
                 if path.isfile(x_file) then
                    candidates[#candidates + 1] = x_file
                 end
              end
              if #candidates == 1 then
                 list[#list + 1] = {y = y_file, x = candidates[1]}
              elseif #candidates > 1 then
                 -- The RNG is only consulted when there is an actual choice.
                 list[#list + 1] = {y = y_file, x = candidates[torch.random(1, #candidates)]}
              end
           end
        end
     end
     return list
  end
  -- Write a random sample of the pair list to stdout as two-column CSV.
  --
  -- Each row is  "<y path>","<json>"  where <json> is the cjson encoding of
  -- {x = <x path>, filters = <filters>}. The filters field is left out when the
  -- filter list is empty. Both fields are double-quoted, with embedded quotes
  -- doubled.
  --
  -- @param list    array of {y = ..., x = ...}
  -- @param filters array of filter names (may be empty)
  -- @param rate    fraction of list to emit. A non-empty list always yields at
  --                least one row, and the row count is capped at #list.
  local function output(list, filters, rate)
     local total = #list
     if total == 0 then
        -- Nothing to sample, and torch.randperm rejects n = 0.
        io.stderr:write("make_pairwise_list: no (y, x) pairs found\n")
        return
     end
     local n = math.floor(total * rate)
     if n == 0 then
        n = 1
     elseif n > total then
        -- rate > 1 would otherwise index past the end of the permutation.
        n = total
     end
     local function csv_quote(s)
        return '"' .. s:gsub('"', '""') .. '"'
     end
     -- A nil value leaves the "filters" key out of the table being encoded.
     local json_filters = nil
     if #filters > 0 then
        json_filters = filters
     end
     local perm = torch.randperm(total)
     for i = 1, n do
        local v = list[perm[i]]
        io.stdout:write(csv_quote(v.y) .. "," .. csv_quote(cjson.encode({x = v.x, filters = json_filters})) .. "\n")
     end
  end
  -- Collect the candidate x directories from the parsed options.
  --
  -- Takes the value of every option whose name starts with "x_dir" (-x_dir,
  -- -x_dir1 .. -x_dir9), skipping empty values, ordered by option name.
  --
  -- Every x_dir* option defaults to "", so unset options must be dropped.
  -- Otherwise, path.join("", basename) yields a bare relative name that is
  -- looked up in the current working directory. The empty entries would also
  -- keep #x_dirs above 0, so the caller's "no x directory" check could never
  -- fire.
  --
  -- Keys are sorted because pairs() order is unspecified. A directory's
  -- position in the result decides which one torch.random(1, n) selects.
  --
  -- @param opt table of parsed command-line options
  -- @return array of directory strings
  local function get_xdirs(opt)
     local keys = {}
     for k, v in pairs(opt) do
        if type(k) == "string" and k:sub(1, 5) == "x_dir" and v ~= "" then
           keys[#keys + 1] = k
        end
     end
     table.sort(keys)
     local x_dirs = {}
     for _, k in ipairs(keys) do
        x_dirs[#x_dirs + 1] = opt[k]
     end
     return x_dirs
  end
  -- Command-line driver. It declares the options, seeds the torch RNG, gathers
  -- the x directories and builds the pair list, either from a basename list
  -- file or by scanning y_dir. Finally it prints the sampled CSV to stdout.
  local cmd = torch.CmdLine()
  cmd:text("waifu2x make_pairwise_list")
  cmd:option("-x_dir", "", 'Specify the directory for x(input)')
  cmd:option("-y_dir", "", 'Specify the directory for y(groundtruth). The filenames should be same as x_dir')
  cmd:option("-rate", 1, 'sampling rate')
  cmd:option("-file_list", "", 'Specify the basename list (optional)')
  cmd:option("-filters", "", 'Specify the downsampling filters')
  -- Additional x directories -x_dir1 .. -x_dir9, registered in numeric order.
  for slot = 1, 9 do
     cmd:option("-x_dir" .. slot, "", 'x for random choice')
  end

  -- Fixed seed for the torch RNG used by the pair builders and by output().
  torch.manualSeed(71)
  local opt = cmd:parse(arg)
  local x_dirs = get_xdirs(opt)

  local missing_input = opt.y_dir:len() == 0 or #x_dirs == 0
  if missing_input then
     cmd:help()
     os.exit(1)
  end

  -- Exactly one builder runs. Both return a table, so the and/or idiom is safe.
  local use_list_file = opt.file_list:len() > 0
  local list = use_list_file
     and pairwise_from_list(opt.y_dir, x_dirs, opt.file_list)
     or pairwise_from_entries(opt.y_dir, x_dirs)

  local filter_names = utils.split(opt.filters, ",")
  output(list, filter_names, opt.rate)