settings.lua 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. require 'xlua'
  2. require 'pl'
  -- Global settings module.
  -- If something has already been registered under package.preload.settings,
  -- hand that back instead of re-parsing the command line below.
  -- NOTE(review): package.preload normally maps module names to *loader
  -- functions*, and ordinary re-`require` caching happens via package.loaded,
  -- so this guard only fires when a caller explicitly stored a value here.
  local preset = package.preload.settings
  if preset then
    return preset
  end
  -- Tensors created after this point default to torch.FloatTensor.
  torch.setdefaulttensortype('torch.FloatTensor')

  -- Table returned by this module; populated from the parsed options below.
  local settings = {}

  -- Command-line interface. Each spec is { flag, default, help } and is
  -- registered with torch.CmdLine in exactly this order.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("waifu2x")
  cmd:text("Options:")
  local option_specs = {
    { "-seed", 11, 'fixed input seed' },
    { "-data_dir", "./data", 'data directory' },
    { "-test", "images/miku_small.png", 'test image file' },
    { "-model_dir", "./models", 'model directory' },
    { "-method", "scale", '(noise|scale|noise_scale)' },
    { "-noise_level", 1, '(1|2)' },
    { "-color", 'rgb', '(y|rgb)' },
    { "-scale", 2.0, 'scale' },
    { "-learning_rate", 0.00025, 'learning rate for adam' },
    { "-random_half", 1, 'enable data augmentation using half resolution image' },
    { "-crop_size", 128, 'crop size' },
    { "-batch_size", 2, 'mini batch size' },
    { "-epoch", 200, 'epoch' },
    { "-core", 2, 'cpu core' },
  }
  for _, spec in ipairs(option_specs) do
    cmd:option(spec[1], spec[2], spec[3])
  end

  -- Parse the script arguments and copy every option into `settings`.
  local opt = cmd:parse(arg)
  for key, value in pairs(opt) do
    settings[key] = value
  end
  -- Derive the model checkpoint path from the chosen method. Each supported
  -- method maps to a formatter that reads the already-parsed settings.
  local model_file_builders = {
    noise = function(s)
      return string.format("%s/noise%d_model.t7", s.model_dir, s.noise_level)
    end,
    scale = function(s)
      return string.format("%s/scale%.1fx_model.t7", s.model_dir, s.scale)
    end,
    noise_scale = function(s)
      return string.format("%s/noise%d_scale%.1fx_model.t7",
                           s.model_dir, s.noise_level, s.scale)
    end,
  }
  local build_model_file = model_file_builders[settings.method]
  if build_model_file == nil then
    error("unknown method: " .. settings.method)
  end
  settings.model_file = build_model_file(settings)
  -- Reject option values that the rest of this file cannot handle.
  if not (settings.color == "rgb" or settings.color == "y") then
    error("color must be y or rgb")
  end
  -- The even-integer test below on its own also accepts 0 and negative even
  -- values (in Lua, -2 % 2 == 0). For -method scale, -scale 0 would then
  -- produce a "scale0.0x_model.t7" path. Require a strictly positive scale
  -- first. NaN fails `<= 0` and still falls through to the "mod-2" error.
  if settings.scale <= 0 then
    error("scale must be positive")
  end
  if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
    error("scale must be mod-2")
  end
  -- -random_half is parsed as a number; turn it into a real boolean.
  -- Only the exact value 1 enables it; anything else disables it.
  settings.random_half = (settings.random_half == 1)
  -- Apply the requested thread count (-core) to torch.
  torch.setnumthreads(settings.core)

  -- Data files live directly under -data_dir.
  local data_dir = settings.data_dir
  settings.images = string.format("%s/images.t7", data_dir)
  settings.image_list = string.format("%s/image_list.txt", data_dir)

  -- Fixed validation parameters; they are not exposed as CLI options.
  -- NOTE(review): presumably the held-out fraction of images and the number
  -- of crops per validation image; confirm against the training script.
  settings.validation_ratio = 0.1
  settings.validation_crops = 40
  -- Choose the network constructor and its block offset. Only the 4x
  -- upscaling methods use srcnn.waifu4x; everything else uses srcnn.waifu2x.
  -- NOTE(review): block_offset meaning is defined by srcnn/training code,
  -- not here; 13 and 7 are paired with waifu4x and waifu2x respectively.
  local srcnn = require './srcnn'
  local is_upscale_method = settings.method == "scale"
    or settings.method == "noise_scale"
  local wants_4x = is_upscale_method and settings.scale == 4
  if wants_4x then
    settings.create_model = srcnn.waifu4x
    settings.block_offset = 13
  else
    settings.create_model = srcnn.waifu2x
    settings.block_offset = 7
  end
  return settings