make_benchmark_input.lua 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. require 'pl'
  2. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  3. package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
  4. require 'xlua'
  5. local iproc = require 'iproc'
  6. local image_loader = require 'image_loader'
  7. local gm = require 'graphicsmagick'
  -- Command-line interface.
  -- The main loop writes the downscaled images into opt.lr and the original
  -- (cropped) images into opt.hr, so each directory option's default and help
  -- text must match that role.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-make benchmark data")
  cmd:text("Options:")
  cmd:option("-i", "./data/test", 'input dir')
  cmd:option("-lr", "lr", 'lowres output dir')
  cmd:option("-hr", "hr", 'highres output dir')
  cmd:option("-filter", "Sinc", 'downsampling filter')
  -- load_data_from_dir() reads opt.show_progress; declare it here so the flag can
  -- actually be set. A boolean default makes it a switch in torch.CmdLine, and
  -- false preserves the previous behaviour (no progress bar).
  cmd:option("-show_progress", false, 'show progress bar while loading images')
  local opt = cmd:parse(arg)
  torch.setdefaulttensortype('torch.FloatTensor')
  --- Produce the low-resolution counterpart of an image.
  -- Halves dimension 3 and dimension 2 of `x` and passes them to iproc.scale
  -- (presumably width and height of a channels x height x width tensor --
  -- confirm against iproc.scale).
  -- @param x image tensor
  -- @param opt parsed options; opt.filter is forwarded as the filter name
  -- @return whatever iproc.scale returns (the scaled image)
  local function transform_scale(x, opt)
    local half_dim3 = x:size(3) * 0.5
    local half_dim2 = x:size(2) * 0.5
    -- the trailing 1 is passed through unchanged; its meaning is defined by
    -- iproc.scale (NOTE(review): document once confirmed)
    return iproc.scale(x, half_dim3, half_dim2, opt.filter, 1)
  end
  --- Load every readable image from a directory.
  -- Files are listed with dir.getfiles(test_dir, "*.*"). Each file is read with
  -- image_loader.load_byte; files for which it returns a falsy value are
  -- skipped. Loaded images are cropped with iproc.crop_mod4 and stored together
  -- with the file name minus its extension.
  -- Every 10th file the garbage collector is run, and the progress bar is
  -- updated when opt.show_progress is truthy.
  -- @param test_dir directory to scan
  -- @return array of { y = cropped image, basename = file name without extension }
  local function load_data_from_dir(test_dir)
    local entries = {}
    local file_list = dir.getfiles(test_dir, "*.*")
    local total = #file_list
    for idx, file_path in ipairs(file_list) do
      local file_name = path.basename(file_path)
      local ext = path.extension(file_name)
      local stem = file_name:sub(1, #file_name - #ext)
      local loaded = image_loader.load_byte(file_path)
      if loaded then
        entries[#entries + 1] = { y = iproc.crop_mod4(loaded), basename = stem }
      end
      if idx % 10 == 0 then
        if opt.show_progress then
          xlua.progress(idx, total)
        end
        collectgarbage()
      end
    end
    return entries
  end
  -- Script entry: build the benchmark pairs.
  -- For every image loaded from opt.i, save the cropped original as
  -- <opt.hr>/<basename>.png and its downscaled version as <opt.lr>/<basename>.png.

  -- Require torch's image package explicitly rather than depending on a global
  -- `image` that some other module happened to create.
  local image = require 'image'

  dir.makepath(opt.lr)
  dir.makepath(opt.hr)
  local files = load_data_from_dir(opt.i)
  for i = 1, #files do
    local sample = files[i]
    local y = sample.y
    local x = transform_scale(y, opt)
    local file_name = sample.basename .. ".png"
    image.save(path.join(opt.hr, file_name), y)
    image.save(path.join(opt.lr, file_name), x)
  end