w2nn.lua 826 B

123456789101112131415161718192021222324252627282930313233343536
  1. local function load_nn()
  2. require 'torch'
  3. require 'nn'
  4. end
  5. local function load_cunn()
  6. require 'cutorch'
  7. require 'cunn'
  8. end
  9. local function load_cudnn()
  10. cudnn = require('cudnn')
  11. end
  12. if w2nn then
  13. return w2nn
  14. else
  15. w2nn = {}
  16. local state, ret = pcall(load_cunn)
  17. if not state then
  18. error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. ret)
  19. end
  20. pcall(load_cudnn)
  21. function w2nn.load_model(model_path, force_cudnn)
  22. local model = torch.load(model_path, "ascii")
  23. if force_cudnn then
  24. model = cudnn.convert(model, cudnn)
  25. end
  26. model:cuda():evaluate()
  27. return model
  28. end
  29. require 'LeakyReLU'
  30. require 'ClippedWeightedHuberCriterion'
  31. require 'ClippedMSECriterion'
  32. require 'SSIMCriterion'
  33. require 'InplaceClip01'
  34. return w2nn
  35. end