settings.lua 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. require 'xlua'
  2. require 'pl'
  3. require 'trepl'
  4. require 'cutorch'
  5. -- global settings
  6. if package.preload.settings then
  7. return package.preload.settings
  8. end
  9. -- default tensor type
  10. torch.setdefaulttensortype('torch.FloatTensor')
  11. local settings = {}
  12. local cmd = torch.CmdLine()
  13. cmd:text()
  14. cmd:text("waifu2x-training")
  15. cmd:text("Options:")
  16. cmd:option("-gpu", -1, 'GPU Device ID')
  17. cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
  18. cmd:option("-data_dir", "./data", 'path to data directory')
  19. cmd:option("-backend", "cunn", '(cunn|cudnn)')
  20. cmd:option("-test", "images/miku_small.png", 'path to test image')
  21. cmd:option("-model_dir", "./models", 'model directory')
  22. cmd:option("-method", "scale", 'method to training (noise|scale|noise_scale|user)')
  23. cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)')
  24. cmd:option("-noise_level", 1, '(0|1|2|3)')
  25. cmd:option("-style", "art", '(art|photo)')
  26. cmd:option("-color", 'rgb', '(y|rgb)')
  27. cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
  28. cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
  29. cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
  30. cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
  31. cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0-1.0)')
  32. cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
  33. cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
  34. cmd:option("-random_blur_sigma_max", 1.0, 'max sigma for random gaussian blur')
  35. cmd:option("-random_pairwise_scale_rate", 0.0, 'data augmentation using pairwise resize for user method')
  36. cmd:option("-random_pairwise_scale_min", 0.85, 'min scale factor for random pairwise scale')
  37. cmd:option("-random_pairwise_scale_max", 1.176, 'max scale factor for random pairwise scale')
  38. cmd:option("-random_pairwise_rotate_rate", 0.0, 'data augmentation using pairwise resize for user method')
  39. cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairwise rotate')
  40. cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
  41. cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
  42. cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
  43. cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
  44. cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
  45. cmd:option("-scale", 2.0, 'scale factor (2)')
  46. cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
  47. cmd:option("-crop_size", 48, 'crop size')
  48. cmd:option("-max_size", 256, 'if image is larger than N, image will be crop randomly')
  49. cmd:option("-batch_size", 16, 'mini batch size')
  50. cmd:option("-patches", 64, 'number of patch samples')
  51. cmd:option("-inner_epoch", 4, 'number of inner epochs')
  52. cmd:option("-epoch", 50, 'number of epochs to run')
  53. cmd:option("-thread", -1, 'number of CPU threads')
  54. cmd:option("-jpeg_chroma_subsampling_rate", 0.5, 'the rate of using YUV 4:2:0 in denoising training (0.0-1.0)')
  55. cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
  56. cmd:option("-validation_crops", 200, 'number of cropping region per image in validation')
  57. cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
  58. cmd:option("-active_cropping_tries", 10, 'active cropping tries')
  59. cmd:option("-nr_rate", 0.65, 'trade-off between reducing noise and erasing details (0.0-1.0)')
  60. cmd:option("-save_history", 0, 'save all model (0|1)')
  61. cmd:option("-plot", 0, 'plot loss chart(0|1)')
  62. cmd:option("-downsampling_filters", "Box,Lanczos,Sinc", '(comma separated)downsampling filters for 2x scale training. (Point,Box,Triangle,Hermite,Hanning,Hamming,Blackman,Gaussian,Quadratic,Cubic,Catrom,Mitchell,Lanczos,Bessel,Sinc)')
  63. cmd:option("-max_training_image_size", -1, 'if training image is larger than N, image will be crop randomly when data converting')
  64. cmd:option("-use_transparent_png", 0, 'use transparent png (0|1)')
  65. cmd:option("-resize_blur_min", 0.95, 'min blur parameter for ResizeImage')
  66. cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage')
  67. cmd:option("-oracle_rate", 0.1, '')
  68. cmd:option("-oracle_drop_rate", 0.5, '')
  69. cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
  70. cmd:option("-resume", "", 'resume model file')
  71. cmd:option("-name", "user", 'model name for user method')
  72. cmd:option("-gpu", 1, 'Device ID')
  73. cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
  74. cmd:option("-update_criterion", "mse", 'mse|loss')
  75. local function to_bool(settings, name)
  76. if settings[name] == 1 then
  77. settings[name] = true
  78. else
  79. settings[name] = false
  80. end
  81. end
  82. local opt = cmd:parse(arg)
  83. for k, v in pairs(opt) do
  84. settings[k] = v
  85. end
  86. to_bool(settings, "plot")
  87. to_bool(settings, "save_history")
  88. to_bool(settings, "use_transparent_png")
  89. to_bool(settings, "pairwise_y_binary")
  90. to_bool(settings, "pairwise_flip")
  91. if settings.plot then
  92. require 'gnuplot'
  93. end
  94. if settings.save_history then
  95. if settings.method == "noise" then
  96. settings.model_file = string.format("%s/noise%d_model.%%d-%%d.t7",
  97. settings.model_dir, settings.noise_level)
  98. settings.model_file_best = string.format("%s/noise%d_model.t7",
  99. settings.model_dir, settings.noise_level)
  100. elseif settings.method == "scale" then
  101. settings.model_file = string.format("%s/scale%.1fx_model.%%d-%%d.t7",
  102. settings.model_dir, settings.scale)
  103. settings.model_file_best = string.format("%s/scale%.1fx_model.t7",
  104. settings.model_dir, settings.scale)
  105. elseif settings.method == "noise_scale" then
  106. settings.model_file = string.format("%s/noise%d_scale%.1fx_model.%%d-%%d.t7",
  107. settings.model_dir,
  108. settings.noise_level,
  109. settings.scale)
  110. settings.model_file_best = string.format("%s/noise%d_scale%.1fx_model.t7",
  111. settings.model_dir,
  112. settings.noise_level,
  113. settings.scale)
  114. elseif settings.method == "user" then
  115. settings.model_file = string.format("%s/%s_model.%%d-%%d.t7",
  116. settings.model_dir,
  117. settings.name)
  118. settings.model_file_best = string.format("%s/%s_model.t7",
  119. settings.model_dir,
  120. settings.name)
  121. else
  122. error("unknown method: " .. settings.method)
  123. end
  124. else
  125. if settings.method == "noise" then
  126. settings.model_file = string.format("%s/noise%d_model.t7",
  127. settings.model_dir, settings.noise_level)
  128. elseif settings.method == "scale" then
  129. settings.model_file = string.format("%s/scale%.1fx_model.t7",
  130. settings.model_dir, settings.scale)
  131. elseif settings.method == "noise_scale" then
  132. settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
  133. settings.model_dir, settings.noise_level, settings.scale)
  134. elseif settings.method == "user" then
  135. settings.model_file = string.format("%s/%s_model.t7",
  136. settings.model_dir, settings.name)
  137. else
  138. error("unknown method: " .. settings.method)
  139. end
  140. end
  141. if not (settings.color == "rgb" or settings.color == "y") then
  142. error("color must be y or rgb")
  143. end
  144. if not ( settings.scale == 1 or (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0)) then
  145. error("scale must be 1 or mod-2")
  146. end
  147. if not (settings.style == "art" or
  148. settings.style == "photo") then
  149. error(string.format("unknown style: %s", settings.style))
  150. end
  151. if settings.thread > 0 then
  152. torch.setnumthreads(tonumber(settings.thread))
  153. end
  154. if settings.downsampling_filters and settings.downsampling_filters:len() > 0 then
  155. settings.downsampling_filters = settings.downsampling_filters:split(",")
  156. else
  157. settings.downsampling_filters = {"Box", "Lanczos", "Catrom"}
  158. end
  159. settings.images = string.format("%s/images.t7", settings.data_dir)
  160. settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
  161. cutorch.setDevice(opt.gpu)
  162. return settings