w2nn.lua 740 B

1234567891011121314151617181920212223242526272829303132333435
  1. local function load_nn()
  2. require 'torch'
  3. require 'nn'
  4. end
  5. local function load_cunn()
  6. require 'cutorch'
  7. require 'cunn'
  8. end
  9. local function load_cudnn()
  10. require 'cudnn'
  11. cudnn.benchmark = true
  12. end
  13. if w2nn then
  14. return w2nn
  15. else
  16. pcall(load_cunn)
  17. pcall(load_cudnn)
  18. w2nn = {}
  19. function w2nn.load_model(model_path, force_cudnn)
  20. local model = torch.load(model_path, "ascii")
  21. if force_cudnn then
  22. model = cudnn.convert(model, cudnn)
  23. end
  24. model:cuda():evaluate()
  25. return model
  26. end
  27. require 'LeakyReLU'
  28. require 'LeakyReLU_deprecated'
  29. require 'DepthExpand2x'
  30. require 'PSNRCriterion'
  31. require 'ClippedWeightedHuberCriterion'
  32. require 'ClippedMSECriterion'
  33. return w2nn
  34. end