-- settings.lua (1.8 KB)

  require("torch")
  require("cutorch")
  require("xlua")
  require("pl")

  -- Global settings module: builds one settings table from the command line
  -- and returns it to whoever requires this file.
  --
  -- NOTE(review): package.preload normally holds loader functions, and
  -- nothing in this file ever writes package.preload.settings. This early
  -- return therefore presumably only fires if some other script pre-seeds
  -- that slot with a ready-made settings table. Confirm before relying on it.
  local preset = package.preload.settings
  if preset then
    return preset
  end

  -- Tensors created after this point default to torch.FloatTensor.
  torch.setdefaulttensortype('torch.FloatTensor')
  -- Command-line options as {flag, default value, help text}.
  -- They are registered with torch.CmdLine in exactly this order.
  local option_specs = {
    { "-seed", 11, 'fixed input seed' },
    { "-data_dir", "./data", 'data directory' },
    { "-test", "images/miku_small.png", 'test image file' },
    { "-model_dir", "./models", 'model directory' },
    { "-method", "scale", '(noise|scale)' },
    { "-noise_level", 1, '(1|2)' },
    { "-scale", 2.0, 'scale' },
    { "-learning_rate", 0.00025, 'learning rate for adam' },
    { "-crop_size", 128, 'crop size' },
    { "-batch_size", 2, 'mini batch size' },
    { "-epoch", 200, 'epoch' },
    { "-core", 2, 'cpu core' },
  }

  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x")
  cmd:text("Options:")
  for _, spec in ipairs(option_specs) do
    cmd:option(spec[1], spec[2], spec[3])
  end

  -- Parse the script's `arg` table and copy every parsed option into a fresh
  -- settings table. The rest of this file adds derived fields to it.
  local settings = {}
  for name, value in pairs(cmd:parse(arg)) do
    settings[name] = value
  end
  -- Resolve model file paths under -model_dir for the selected -method.
  -- "scale" also records the denoise model path for the chosen -noise_level.
  -- Any other method is rejected.
  local method = settings.method
  if method == "scale" then
    settings.model_file = string.format("%s/scale%.1fx_model.t7",
                                        settings.model_dir, settings.scale)
    settings.denoise_model_file = string.format("%s/noise%d_model.t7",
                                                settings.model_dir, settings.noise_level)
  elseif method == "noise" then
    settings.model_file = string.format("%s/noise%d_model.t7",
                                        settings.model_dir, settings.noise_level)
  else
    error("unknown method: " .. method)
  end

  -- -scale must be an even integer value. The check runs for every method,
  -- not only "scale".
  local scale = settings.scale
  if scale ~= math.floor(scale) or scale % 2 ~= 0 then
    error("scale must be mod-2")
  end
  -- Thread count torch uses, taken from the -core option.
  torch.setnumthreads(settings.core)

  -- Data files derived from -data_dir.
  settings.images = string.format("%s/images.t7", settings.data_dir)
  settings.image_list = string.format("%s/image_list.txt", settings.data_dir)

  -- Presumably the fraction of the data held out for validation (the consumer
  -- is not visible here). It must be a fraction below 1. Note that a literal
  -- written as `01` lexes as the number 1, which would hold out everything.
  settings.validation_ratio = 0.1
  settings.validation_crops = 40
  settings.block_offset = 7 -- see srcnn.lua

  return settings