settings.lua 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. require 'xlua'
  2. require 'pl'
  3. require 'trepl'
  -- Global settings module.
  --
  -- When a value is already stored at package.preload.settings, hand that
  -- value back and skip the rest of this script (no option parsing happens).
  -- NOTE(review): package.preload conventionally holds loader functions, not
  -- module values; this guard returns whatever is stored there as-is. Confirm
  -- that callers place the settings table itself under this key.
  local preloaded = package.preload.settings
  if preloaded then
    return preloaded
  end

  -- Make torch.FloatTensor the default tensor type.
  torch.setdefaulttensortype('torch.FloatTensor')
  -- Table that becomes this module's return value; it is filled from the
  -- parsed command line below.
  local settings = {}

  -- Command-line parser and the banner text shown ahead of the option list.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x")
  cmd:text("Options:")

  -- Each entry is {flag, default, help}; options are registered in the order
  -- listed here. A "-backend" (cunn|cudnn) option is deliberately left out:
  -- cudnn was noted to be slower than cunn.
  do
    local option_specs = {
      {"-seed", 11, 'fixed input seed'},
      {"-data_dir", "./data", 'data directory'},
      {"-test", "images/miku_small.png", 'test image file'},
      {"-model_dir", "./models", 'model directory'},
      {"-method", "scale", '(noise|scale|noise_scale)'},
      {"-noise_level", 1, '(1|2)'},
      {"-category", "anime_style_art", '(anime_style_art|photo)'},
      {"-color", 'rgb', '(y|rgb)'},
      {"-color_noise", 0, 'enable data augmentation using color noise (1|0)'},
      {"-overlay", 0, 'enable data augmentation using overlay (1|0)'},
      {"-scale", 2.0, 'scale'},
      {"-learning_rate", 0.00025, 'learning rate for adam'},
      {"-random_half", 1, 'enable data augmentation using half resolution image (0|1)'},
      {"-crop_size", 128, 'crop size'},
      {"-max_size", -1, 'crop if image size larger then this value.'},
      {"-batch_size", 2, 'mini batch size'},
      {"-epoch", 200, 'epoch'},
      {"-thread", -1, 'number of CPU threads'},
      {"-jpeg_sampling_factors", 444, '(444|422)'},
      {"-validation_ratio", 0.1, 'validation ratio'},
      {"-validation_crops", 40, 'number of crop region in validation'},
      {"-active_cropping_rate", 0.5, 'active cropping rate'},
      {"-active_cropping_tries", 20, 'active cropping tries'},
    }
    for _, spec in ipairs(option_specs) do
      cmd:option(spec[1], spec[2], spec[3])
    end
  end

  -- Parse the script arguments and copy every parsed value into settings.
  local opt = cmd:parse(arg)
  for name, value in pairs(opt) do
    settings[name] = value
  end
  -- Derive settings.model_file from the selected method. Each supported
  -- method maps to a builder that formats the path inside settings.model_dir;
  -- any other method value raises an error.
  do
    local builders = {
      noise = function()
        return string.format("%s/noise%d_model.t7",
                             settings.model_dir, settings.noise_level)
      end,
      scale = function()
        return string.format("%s/scale%.1fx_model.t7",
                             settings.model_dir, settings.scale)
      end,
      noise_scale = function()
        return string.format("%s/noise%d_scale%.1fx_model.t7",
                             settings.model_dir, settings.noise_level, settings.scale)
      end,
    }
    local build = builders[settings.method]
    if not build then
      error("unknown method: " .. settings.method)
    end
    settings.model_file = build()
  end
  -- Validate option values; each failed check raises an error.

  -- Color mode must be one of the two supported values.
  if settings.color ~= "rgb" and settings.color ~= "y" then
    error("color must be y or rgb")
  end

  -- Scale must be a positive even integer.
  do
    local scale = settings.scale
    -- Integer and mod-2 test (unchanged message for inputs that already failed).
    if not (scale == math.floor(scale) and scale % 2 == 0) then
      error("scale must be mod-2")
    end
    -- 0 and negative even values (e.g. -2) pass the test above on their own,
    -- so positivity has to be checked explicitly.
    if not (scale > 0) then
      error("scale must be positive")
    end
  end

  -- Category must be one of the known image categories.
  if settings.category ~= "anime_style_art" and settings.category ~= "photo" then
    error(string.format("unknown category: %s", settings.category))
  end
  -- Convert the numeric 0/1 switches into booleans: exactly 1 becomes true,
  -- every other value becomes false.
  for _, flag in ipairs({"random_half", "color_noise", "overlay"}) do
    settings[flag] = (settings[flag] == 1)
  end
  -- Override Torch's thread count only when a positive value was requested;
  -- with the default of -1 the call is skipped.
  local thread_count = settings.thread
  if thread_count > 0 then
    torch.setnumthreads(tonumber(thread_count))
  end

  -- Build a path to a file located inside the configured data directory.
  local function in_data_dir(file_name)
    return string.format("%s/%s", settings.data_dir, file_name)
  end

  settings.images = in_data_dir("images.t7")
  settings.image_list = in_data_dir("image_list.txt")

  -- The backend is fixed here rather than read from the command line.
  settings.backend = "cunn"

  return settings