-- settings.lua (4.8 KB)
  1. require 'xlua'
  2. require 'pl'
  3. require 'trepl'
  -- Global settings.
  -- If package.preload.settings is already set (any truthy value), hand it
  -- back as-is and skip the rest of this file (no command-line parsing).
  -- NOTE(review): package.preload conventionally stores loader *functions*
  -- keyed by module name; confirm what the rest of the project stores here.
  local preloaded_settings = package.preload.settings
  if preloaded_settings then
    return preloaded_settings
  end
  -- Make torch.FloatTensor the default tensor type.
  torch.setdefaulttensortype('torch.FloatTensor')

  -- Table that collects every resolved option; returned at the end of the file.
  local settings = {}

  -- Command-line parser with its help header.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x-training")
  cmd:text("Options:")

  -- Option definitions, registered with cmd:option in the order listed.
  -- Each entry is {flag, default value, help text}.
  local option_specs = {
    {"-gpu", -1, 'GPU Device ID'},
    {"-seed", 11, 'RNG seed'},
    {"-data_dir", "./data", 'path to data directory'},
    {"-backend", "cunn", '(cunn|cudnn)'},
    {"-test", "images/miku_small.png", 'path to test image'},
    {"-model_dir", "./models", 'model directory'},
    {"-method", "scale", 'method to training (noise|scale)'},
    {"-noise_level", 1, '(1|2|3)'},
    {"-style", "art", '(art|photo)'},
    {"-color", 'rgb', '(y|rgb)'},
    {"-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)'},
    {"-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)'},
    {"-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)'},
    {"-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)'},
    {"-scale", 2.0, 'scale factor (2)'},
    {"-learning_rate", 0.0005, 'learning rate for adam'},
    {"-crop_size", 46, 'crop size'},
    {"-max_size", 256, 'if image is larger than N, image will be crop randomly'},
    {"-batch_size", 8, 'mini batch size'},
    {"-patches", 16, 'number of patch samples'},
    {"-inner_epoch", 4, 'number of inner epochs'},
    {"-epoch", 30, 'number of epochs to run'},
    {"-thread", -1, 'number of CPU threads'},
    {"-jpeg_chroma_subsampling_rate", 0.0, 'the rate of YUV 4:2:0/YUV 4:4:4 in denoising training (0.0-1.0)'},
    {"-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)'},
    {"-validation_crops", 160, 'number of cropping region per image in validation'},
    {"-active_cropping_rate", 0.5, 'active cropping rate'},
    {"-active_cropping_tries", 10, 'active cropping tries'},
    {"-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)'},
    {"-save_history", 0, 'save all model (0|1)'},
    {"-plot", 0, 'plot loss chart(0|1)'},
    {"-downsampling_filters", "Box,Catrom", '(comma separated)downsampling filters for 2x scale training. (Point,Box,Triangle,Hermite,Hanning,Hamming,Blackman,Gaussian,Quadratic,Cubic,Catrom,Mitchell,Lanczos,Bessel,Sinc)'},
    {"-gamma_correction", 0, 'Resizing with colorspace correction(sRGB:gamma 2.2) in scale training (0|1)'},
    {"-upsampling_filter", "Box", 'upsampling filter for 2x scale training (dev)'},
    {"-max_training_image_size", -1, 'if training image is larger than N, image will be crop randomly when data converting'},
  }
  for _, spec in ipairs(option_specs) do
    cmd:option(spec[1], spec[2], spec[3])
  end
  --- Convert a numeric 0/1 flag stored in a table into a Lua boolean, in place.
  -- Only a value equal to the number 1 becomes true; every other value
  -- (0, nil, 2, the string "1", ...) becomes false.
  -- @tparam table settings table holding the flag (mutated)
  -- @tparam string name key of the flag to convert
  local function to_bool(settings, name)
    settings[name] = (settings[name] == 1)
  end
  -- Parse the process arguments and copy every parsed option into `settings`.
  local opt = cmd:parse(arg)
  for option_name, option_value in pairs(opt) do
    settings[option_name] = option_value
  end

  -- These flags are passed as 0/1 on the command line; store them as booleans.
  for _, flag_name in ipairs({"plot", "save_history", "gamma_correction"}) do
    to_bool(settings, flag_name)
  end

  -- gnuplot is only loaded when loss plotting was requested.
  if settings.plot then
    require 'gnuplot'
  end
  -- Build settings.model_file from the training method.
  --   noise: <model_dir>/noise<noise_level>_model
  --   scale: <model_dir>/scale<scale, one decimal>x_model
  -- With -save_history the name ends in the literal ".%d-%d.t7", leaving two
  -- %d placeholders in the path (presumably filled in later by the trainer,
  -- e.g. with epoch counters -- TODO confirm against the caller).
  -- Otherwise it ends in ".t7".
  -- The method is dispatched once here so the two naming variants cannot
  -- drift apart.
  do
    local model_file_base
    if settings.method == "noise" then
      model_file_base = string.format("%s/noise%d_model",
                                      settings.model_dir, settings.noise_level)
    elseif settings.method == "scale" then
      model_file_base = string.format("%s/scale%.1fx_model",
                                      settings.model_dir, settings.scale)
    else
      error("unknown method: " .. settings.method)
    end
    if settings.save_history then
      settings.model_file = model_file_base .. ".%d-%d.t7"
    else
      settings.model_file = model_file_base .. ".t7"
    end
  end
  -- Validate options that only accept a restricted set of values.
  if not (settings.color == "rgb" or settings.color == "y") then
    error("color must be y or rgb")
  end
  -- The scale factor must be a positive even integer (2, 4, ...).
  -- 0 and negative even numbers also satisfy `scale % 2 == 0`, so they are
  -- rejected explicitly before the mod-2 test.
  if settings.scale <= 0 then
    error("scale must be positive")
  end
  if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
    error("scale must be mod-2")
  end
  if not (settings.style == "art" or settings.style == "photo") then
    error(string.format("unknown style: %s", settings.style))
  end
  -- A positive -thread value sets torch's CPU thread count; a non-positive
  -- value (the default is -1) leaves torch's setting untouched.
  do
    local cpu_threads = settings.thread
    if cpu_threads > 0 then
      torch.setnumthreads(tonumber(cpu_threads))
    end
  end
  -- Turn the comma-separated -downsampling_filters string (e.g. "Box,Catrom")
  -- into a list of filter names. Whitespace around each name is trimmed and
  -- empty entries are skipped, so "Box, Catrom" and "Box,,Catrom," both
  -- yield {"Box", "Catrom"}. If the option is absent, not a string, empty,
  -- or contains no names, the default list below is used.
  do
    local filters = {}
    local spec = settings.downsampling_filters
    if type(spec) == "string" then
      for entry in spec:gmatch("[^,]+") do
        local filter_name = entry:match("^%s*(.-)%s*$")
        if filter_name ~= "" then
          filters[#filters + 1] = filter_name
        end
      end
    end
    if #filters == 0 then
      filters = {"Box", "Lanczos", "Catrom"}
    end
    settings.downsampling_filters = filters
  end
  -- File paths under -data_dir: images.t7 and image_list.txt.
  local data_dir = settings.data_dir
  settings.images = string.format("%s/images.t7", data_dir)
  settings.image_list = string.format("%s/image_list.txt", data_dir)
  return settings