settings.lua 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  require 'xlua'
  require 'pl'

  -- Global settings module.
  -- When a settings value is already registered under package.preload, hand
  -- it back instead of re-parsing the command line.
  -- NOTE(review): nothing in this file writes package.preload.settings, and
  -- require() caches module results in package.loaded (package.preload is
  -- normally where loader functions live), so this early return only fires if
  -- some other file stores a value there explicitly -- confirm that is intended.
  local preloaded_settings = package.preload.settings
  if preloaded_settings then
    return preloaded_settings
  end
  -- Tensors created anywhere after this point default to FloatTensor.
  torch.setdefaulttensortype('torch.FloatTensor')

  local settings = {}

  -- Command-line options as { flag, default value, help text }, listed in the
  -- order --help should print them.
  local option_specs = {
    {"-seed", 11, 'fixed input seed'},
    {"-data_dir", "./data", 'data directory'},
    {"-test", "images/miku_small.png", 'test image file'},
    {"-model_dir", "./models", 'model directory'},
    {"-method", "scale", '(noise|scale|noise_scale)'},
    {"-noise_level", 1, '(1|2)'},
    {"-category", "anime_style_art", '(anime_style_art|photo)'},
    {"-color", 'rgb', '(y|rgb)'},
    {"-color_noise", 0, 'enable data augmentation using color noise (1|0)'},
    {"-overlay", 0, 'enable data augmentation using overlay (1|0)'},
    {"-scale", 2.0, 'scale'},
    {"-learning_rate", 0.00025, 'learning rate for adam'},
    {"-random_half", 1, 'enable data augmentation using half resolution image (0|1)'},
    {"-crop_size", 128, 'crop size'},
    {"-batch_size", 2, 'mini batch size'},
    {"-epoch", 200, 'epoch'},
    {"-core", 2, 'cpu core'},
  }

  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x")
  cmd:text("Options:")
  -- ipairs keeps the declaration order of the options.
  for _, spec in ipairs(option_specs) do
    cmd:option(spec[1], spec[2], spec[3])
  end

  -- Parse the script arguments and copy every resulting option into settings.
  for name, value in pairs(cmd:parse(arg)) do
    settings[name] = value
  end
  -- Derive the model file path from the selected method. Each entry maps a
  -- method name to a function that formats its path from the settings table.
  local model_path_builders = {
    noise = function(s)
      return string.format("%s/noise%d_model.t7", s.model_dir, s.noise_level)
    end,
    scale = function(s)
      return string.format("%s/scale%.1fx_model.t7", s.model_dir, s.scale)
    end,
    noise_scale = function(s)
      return string.format("%s/noise%d_scale%.1fx_model.t7",
                           s.model_dir, s.noise_level, s.scale)
    end,
  }
  local build_model_path = model_path_builders[settings.method]
  if not build_model_path then
    error("unknown method: " .. settings.method)
  end
  settings.model_file = build_model_path(settings)
  -- Reject option values outside the supported sets before any work starts.
  if not (settings.color == "rgb" or settings.color == "y") then
    error("color must be y or rgb")
  end
  -- The scale factor must be a positive even integer (2, 4, ...). 0 and
  -- negative even numbers satisfy "integer and even" but are not usable
  -- upscaling factors, so positivity is checked as well.
  if not (settings.scale > 0 and
          settings.scale == math.floor(settings.scale) and
          settings.scale % 2 == 0) then
    error("scale must be a positive mod-2 integer")
  end
  if not (settings.category == "anime_style_art" or
          settings.category == "photo") then
    error(string.format("unknown category: %s", settings.category))
  end
  -- The data-augmentation switches arrive from the command line as numbers.
  -- Turn each one into a real boolean; only the value 1 means "enabled".
  for _, flag in ipairs({"random_half", "color_noise", "overlay"}) do
    settings[flag] = (settings[flag] == 1)
  end
  -- Use the requested number of CPU threads for torch operations.
  torch.setnumthreads(settings.core)

  -- Dataset files are located under data_dir.
  local data_dir = settings.data_dir
  settings.images = string.format("%s/images.t7", data_dir)
  settings.image_list = string.format("%s/image_list.txt", data_dir)

  -- Validation parameters are fixed here rather than exposed as options.
  -- NOTE(review): presumably the held-out fraction of images and the number
  -- of crops taken per validation image -- confirm against the trainer.
  settings.validation_ratio = 0.1
  settings.validation_crops = 30
  -- Select the network constructor. Only a 4x upscaling run (method "scale"
  -- or "noise_scale" with scale == 4) uses srcnn.waifu4x; every other
  -- combination, including pure "noise", uses srcnn.waifu2x.
  -- NOTE(review): block_offset presumably is the border (in pixels) that the
  -- chosen network trims from its input -- confirm against srcnn.lua.
  local srcnn = require './srcnn'
  local upscales_4x = settings.scale == 4 and
    (settings.method == "scale" or settings.method == "noise_scale")
  if upscales_4x then
    settings.create_model, settings.block_offset = srcnn.waifu4x, 13
  else
    settings.create_model, settings.block_offset = srcnn.waifu2x, 7
  end
  return settings