-- settings.lua
  1. require 'xlua'
  2. require 'pl'
  3. -- global settings
  4. if package.preload.settings then
  5. return package.preload.settings
  6. end
  7. -- default tensor type
  8. torch.setdefaulttensortype('torch.FloatTensor')
  9. local settings = {}
  10. local cmd = torch.CmdLine()
  11. cmd:text()
  12. cmd:text("waifu2x")
  13. cmd:text("Options:")
  14. cmd:option("-seed", 11, 'fixed input seed')
  15. cmd:option("-data_dir", "./data", 'data directory')
  16. cmd:option("-test", "images/miku_small.png", 'test image file')
  17. cmd:option("-model_dir", "./models", 'model directory')
  18. cmd:option("-method", "scale", '(noise|scale|noise_scale)')
  19. cmd:option("-noise_level", 1, '(1|2)')
  20. cmd:option("-scale", 2.0, 'scale')
  21. cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
  22. cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
  23. cmd:option("-crop_size", 128, 'crop size')
  24. cmd:option("-batch_size", 2, 'mini batch size')
  25. cmd:option("-epoch", 200, 'epoch')
  26. cmd:option("-core", 2, 'cpu core')
  27. local opt = cmd:parse(arg)
  28. for k, v in pairs(opt) do
  29. settings[k] = v
  30. end
  31. if settings.method == "noise" then
  32. settings.model_file = string.format("%s/noise%d_model.t7",
  33. settings.model_dir, settings.noise_level)
  34. elseif settings.method == "scale" then
  35. settings.model_file = string.format("%s/scale%.1fx_model.t7",
  36. settings.model_dir, settings.scale)
  37. elseif settings.method == "noise_scale" then
  38. settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
  39. settings.model_dir, settings.noise_level, settings.scale)
  40. else
  41. error("unknown method: " .. settings.method)
  42. end
  43. if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
  44. error("scale must be mod-2")
  45. end
  46. if settings.random_half == 1 then
  47. settings.random_half = true
  48. else
  49. settings.random_half = false
  50. end
  51. torch.setnumthreads(settings.core)
  52. settings.images = string.format("%s/images.t7", settings.data_dir)
  53. settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
  54. settings.validation_ratio = 0.1
  55. settings.validation_crops = 40
  56. local srcnn = require './srcnn'
  57. if (settings.method == "scale" or settings.method == "noise_scale") and settings.scale == 4 then
  58. settings.create_model = srcnn.waifu4x
  59. settings.block_offset = 13
  60. else
  61. settings.create_model = srcnn.waifu2x
  62. settings.block_offset = 7
  63. end
  64. return settings