-- w2nn.lua (original listing: 594 bytes, 28 lines)

  1. local function load_nn()
  2. require 'torch'
  3. require 'nn'
  4. end
  5. local function load_cunn()
  6. require 'cutorch'
  7. require 'cunn'
  8. end
  9. local function load_cudnn()
  10. require 'cudnn'
  11. cudnn.benchmark = true
  12. end
  13. if w2nn then
  14. return w2nn
  15. else
  16. local state, ret = pcall(load_cunn)
  17. if not state then
  18. error("Failed to load CUDA modules. Please check the CUDA Settings.")
  19. end
  20. pcall(load_cudnn)
  21. w2nn = {}
  22. require 'LeakyReLU'
  23. require 'LeakyReLU_deprecated'
  24. require 'DepthExpand2x'
  25. require 'PSNRCriterion'
  26. require 'ClippedWeightedHuberCriterion'
  27. return w2nn
  28. end