-- w2nn.lua
  1. local function load_nn()
  2. require 'torch'
  3. require 'nn'
  4. end
  5. local function load_cunn()
  6. require 'cutorch'
  7. require 'cunn'
  8. end
  9. local function load_cudnn()
  10. require 'cudnn'
  11. cudnn.benchmark = true
  12. end
  13. if w2nn then
  14. return w2nn
  15. else
  16. local state, ret = pcall(load_cunn)
  17. if not state then
  18. error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. ret)
  19. end
  20. pcall(load_cudnn)
  21. w2nn = {}
  22. function w2nn.load_model(model_path, force_cudnn)
  23. local model = torch.load(model_path, "ascii")
  24. if force_cudnn then
  25. model = cudnn.convert(model, cudnn)
  26. end
  27. model:cuda():evaluate()
  28. return model
  29. end
  30. require 'LeakyReLU'
  31. require 'ClippedWeightedHuberCriterion'
  32. require 'ClippedMSECriterion'
  33. return w2nn
  34. end