-- settings.lua (3.3 KB)
  require('xlua')
  require('pl')
  require('trepl')

  -- Global training settings for waifu2x.
  -- Guard: if a value is already stored under package.preload.settings, return it
  -- instead of building a new settings table from the command line.
  -- NOTE(review): nothing in this file writes package.preload.settings, and
  -- package.preload normally holds loader functions rather than module values;
  -- confirm whether any caller actually relies on this guard.
  do
    local preloaded = package.preload.settings
    if preloaded then
      return preloaded
    end
  end

  -- Make torch.FloatTensor the default tensor type for the training code.
  torch.setdefaulttensortype('torch.FloatTensor')

  -- The table this module fills in and returns.
  local settings = {}
  -- Command-line interface of the training script. cmd:parse(arg) returns a
  -- table with one field per option below (leading '-' stripped), holding either
  -- the user-supplied value or the listed default.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-training")
  cmd:text("Options:")
  -- Paths, backend and reproducibility.
  cmd:option("-seed", 11, 'RNG seed')
  cmd:option("-data_dir", "./data", 'path to data directory')
  cmd:option("-backend", "cunn", '(cunn|cudnn)')
  cmd:option("-test", "images/miku_small.png", 'path to test image')
  cmd:option("-model_dir", "./models", 'model directory')
  -- What to train. The help text lists every method accepted when the model
  -- file name is derived later in this file, including noise_scale.
  cmd:option("-method", "scale", 'training method (noise|scale|noise_scale)')
  cmd:option("-noise_level", 1, '(1|2)')
  cmd:option("-style", "art", '(art|photo)')
  cmd:option("-color", 'rgb', '(y|rgb)')
  -- Augmentation switches are parsed as numbers (0/1) and turned into booleans
  -- later in this file (only an exact 1 enables them).
  cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)')
  cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)')
  cmd:option("-scale", 2.0, 'scale factor (2)')
  cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
  cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)')
  -- Sampling / optimisation parameters.
  cmd:option("-crop_size", 46, 'crop size')
  cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
  cmd:option("-batch_size", 8, 'mini batch size')
  cmd:option("-epoch", 200, 'number of total epochs to run')
  -- A value <= 0 leaves torch's thread count unchanged (see the thread check below).
  cmd:option("-thread", -1, 'number of CPU threads')
  cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
  cmd:option("-validation_rate", 0.05, 'validation-set rate of data')
  cmd:option("-validation_crops", 80, 'number of region per image in validation')
  cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
  cmd:option("-active_cropping_tries", 10, 'active cropping tries')
  cmd:option("-nr_rate", 0.7, 'trade-off between reducing noise and erasing details (0.0-1.0)')
  -- `arg` is the global table of script arguments provided by the interpreter.
  local opt = cmd:parse(arg)
  -- Copy every parsed command-line option into the settings table; the code
  -- below then adds derived fields and normalizes some values in place.
  for name, value in pairs(opt) do
    settings[name] = value
  end
  -- Build settings.model_file from the training method; any method without a
  -- builder below is rejected with "unknown method: <name>".
  do
    local model_path_for = {
      noise = function(s)
        return string.format("%s/noise%d_model.t7", s.model_dir, s.noise_level)
      end,
      scale = function(s)
        return string.format("%s/scale%.1fx_model.t7", s.model_dir, s.scale)
      end,
      noise_scale = function(s)
        return string.format("%s/noise%d_scale%.1fx_model.t7",
                             s.model_dir, s.noise_level, s.scale)
      end,
    }
    local build = model_path_for[settings.method]
    if build == nil then
      error("unknown method: " .. settings.method)
    end
    settings.model_file = build(settings)
  end
  -- Reject unsupported option values with an error before training starts.
  do
    local color = settings.color
    if color ~= "rgb" and color ~= "y" then
      error("color must be y or rgb")
    end

    -- The scale factor must be an even integer (e.g. 2, 4).
    local scale = settings.scale
    if scale ~= math.floor(scale) or scale % 2 ~= 0 then
      error("scale must be mod-2")
    end

    local style = settings.style
    if style ~= "art" and style ~= "photo" then
      error(string.format("unknown style: %s", style))
    end
  end
  -- The augmentation switches are numeric on the command line; turn each into a
  -- real boolean that is true only when the value is exactly 1.
  for _, flag in ipairs({ "random_half", "color_noise", "overlay" }) do
    settings[flag] = (settings[flag] == 1)
  end
  -- Only a positive -thread value overrides torch's CPU thread count; the
  -- default of -1 (or any value <= 0) leaves it as is.
  if 0 < settings.thread then
    local requested_threads = tonumber(settings.thread)
    torch.setnumthreads(requested_threads)
  end
  -- Paths of the data files expected inside data_dir, then hand the finished
  -- settings table back to whoever required this module.
  do
    local dir = settings.data_dir
    settings.images = string.format("%s/images.t7", dir)
    settings.image_list = string.format("%s/image_list.txt", dir)
  end
  return settings