w2nn.lua 2.4 KB

-- Lazy loaders for the Torch packages this module depends on.
local function load_nn()
   require 'torch'
   require 'nn'
end
local function load_cunn()
   require 'cutorch'
   require 'cunn'
end
local function load_cudnn()
   -- cudnn is left global on purpose so the rest of the module can test for it.
   cudnn = require('cudnn')
end

-- Wrap `model` in an nn.DataParallelTable spread over the GPU ids in `gpus`.
-- Each worker thread re-requires the dependencies and, when cuDNN is loaded,
-- inherits the fastest/benchmark flags of the main thread.
local function make_data_parallel_table(model, gpus)
   if cudnn then
      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
      local dpt = nn.DataParallelTable(1, true, true)
         :add(model, gpus)
         :threads(function()
               require 'pl'
               local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
               package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
               require 'torch'
               require 'cunn'
               require 'w2nn'
               local cudnn = require 'cudnn'
               cudnn.fastest, cudnn.benchmark = fastest, benchmark
         end)
      dpt.gradInput = nil
      model = dpt:cuda()
   else
      local dpt = nn.DataParallelTable(1, true, true)
         :add(model, gpus)
         :threads(function()
               require 'pl'
               local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
               package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
               require 'torch'
               require 'cunn'
               require 'w2nn'
         end)
      dpt.gradInput = nil
      model = dpt:cuda()
   end
   return model
end

-- If the global w2nn table already exists, return it instead of re-initializing.
if w2nn then
   return w2nn
else
   w2nn = {}
   -- CUDA support is mandatory; cuDNN is optional.
   local state, ret = pcall(load_cunn)
   if not state then
      error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. ret)
   end
   pcall(load_cudnn)

   -- Load a serialized model (ASCII format by default), optionally convert its
   -- layers to cuDNN implementations, and return it on the GPU in evaluation mode.
   function w2nn.load_model(model_path, force_cudnn, mode)
      mode = mode or "ascii"
      local model = torch.load(model_path, mode)
      if force_cudnn then
         model = cudnn.convert(model, cudnn)
      end
      model:cuda():evaluate()
      return model
   end

   -- Use DataParallelTable only when more than one GPU is requested.
   function w2nn.data_parallel(model, gpus)
      if #gpus > 1 then
         return make_data_parallel_table(model, gpus)
      else
         return model
      end
   end

   -- Custom layers and criteria shipped with this package.
   require 'LeakyReLU'
   require 'ClippedWeightedHuberCriterion'
   require 'ClippedMSECriterion'
   require 'SSIMCriterion'
   require 'InplaceClip01'
   require 'L1Criterion'
   require 'ShakeShakeTable'
   require 'PrintTable'
   require 'Print'
   require 'AuxiliaryLossTable'
   require 'AuxiliaryLossCriterion'
   require 'GradWeight'
   require 'RandomBinaryConvolution'
   require 'RandomBinaryCriterion'
   require 'EdgeFilter'
   require 'ScaleTable'
   return w2nn
end
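For context, a minimal usage sketch of the public API defined above (w2nn.load_model and w2nn.data_parallel), assuming the module is on package.path; the model file name and GPU id list are illustrative only, not taken from this file.

-- Minimal usage sketch; "models/example_model.t7" and {1, 2} are hypothetical.
local w2nn = require 'w2nn'

-- Load an ASCII-serialized model and convert its layers to cuDNN where possible.
local model = w2nn.load_model("models/example_model.t7", true)

-- Wrapped in a DataParallelTable only because more than one GPU id is given.
model = w2nn.data_parallel(model, {1, 2})

-- The model is already on the GPU and in evaluate() mode, ready for inference.
local input = torch.rand(1, 3, 64, 64):cuda()
local output = model:forward(input)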