-- settings.lua (2.8 KB)
  -- Global settings for waifu2x.
  -- Parses the command line, validates the options and returns a single table
  -- holding every parsed option plus the values derived from them
  -- (model_file, images, image_list, create_model, block_offset, ...).
  require 'xlua'
  require 'pl'

  -- If a settings value is already present in package.preload, return it
  -- instead of parsing the command line again.
  if package.preload.settings then
     return package.preload.settings
  end

  -- default tensor type
  torch.setdefaulttensortype('torch.FloatTensor')

  local settings = {}
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x")
  cmd:text("Options:")

  -- { flag, default value, help text }, registered in this exact order.
  local option_specs = {
     {"-seed", 11, 'fixed input seed'},
     {"-data_dir", "./data", 'data directory'},
     {"-test", "images/miku_small.png", 'test image file'},
     {"-model_dir", "./models", 'model directory'},
     {"-method", "scale", '(noise|scale|noise_scale)'},
     {"-noise_level", 1, '(1|2)'},
     {"-category", "anime_style_art", '(anime_style_art|photo)'},
     {"-color", 'rgb', '(y|rgb)'},
     {"-color_noise", 0, 'enable data augmentation using color noise (1|0)'},
     {"-scale", 2.0, 'scale'},
     {"-learning_rate", 0.00025, 'learning rate for adam'},
     {"-random_half", 1, 'enable data augmentation using half resolution image (0|1)'},
     {"-crop_size", 128, 'crop size'},
     {"-batch_size", 2, 'mini batch size'},
     {"-epoch", 200, 'epoch'},
     {"-core", 2, 'cpu core'},
  }
  for i = 1, #option_specs do
     local spec = option_specs[i]
     cmd:option(spec[1], spec[2], spec[3])
  end

  -- Copy every parsed option into the settings table.
  local parsed = cmd:parse(arg)
  for name, value in pairs(parsed) do
     settings[name] = value
  end

  -- The model file name depends on the selected method; any method without an
  -- entry here is rejected.
  local model_file_for = {
     noise = function(s)
        return string.format("%s/noise%d_model.t7",
                             s.model_dir, s.noise_level)
     end,
     scale = function(s)
        return string.format("%s/scale%.1fx_model.t7",
                             s.model_dir, s.scale)
     end,
     noise_scale = function(s)
        return string.format("%s/noise%d_scale%.1fx_model.t7",
                             s.model_dir, s.noise_level, s.scale)
     end,
  }
  local build_model_file = model_file_for[settings.method]
  if build_model_file == nil then
     error("unknown method: " .. settings.method)
  end
  settings.model_file = build_model_file(settings)

  -- Option validation: color, then scale, then category.
  if settings.color ~= "rgb" and settings.color ~= "y" then
     error("color must be y or rgb")
  end
  -- scale has to be a whole number that is divisible by two
  if settings.scale ~= math.floor(settings.scale) or settings.scale % 2 ~= 0 then
     error("scale must be mod-2")
  end
  if settings.category ~= "anime_style_art" and settings.category ~= "photo" then
     error(string.format("unknown category: %s", settings.category))
  end

  -- Turn the numeric switches into booleans: only the value 1 enables them.
  settings.random_half = (settings.random_half == 1)
  settings.color_noise = (settings.color_noise == 1)

  torch.setnumthreads(settings.core)

  -- Paths below data_dir and fixed validation parameters.
  settings.images = string.format("%s/images.t7", settings.data_dir)
  settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
  settings.validation_ratio = 0.1
  settings.validation_crops = 30

  -- Choose the model constructor: waifu4x only for a scaling method with
  -- scale == 4, waifu2x for everything else.
  local srcnn = require './srcnn'
  local scaling_method = settings.method == "scale" or settings.method == "noise_scale"
  if scaling_method and settings.scale == 4 then
     settings.create_model = srcnn.waifu4x
     settings.block_offset = 13
  else
     settings.create_model = srcnn.waifu2x
     settings.block_offset = 7
  end

  return settings