-- w2nn.lua
  1. local function load_nn()
  2. require 'torch'
  3. require 'nn'
  4. end
  5. local function load_cunn()
  6. require 'cutorch'
  7. require 'cunn'
  8. end
  9. local function load_cudnn()
  10. cudnn = require('cudnn')
  11. end
  12. local function make_data_parallel_table(model, gpus)
  13. if cudnn then
  14. local fastest, benchmark = cudnn.fastest, cudnn.benchmark
  15. local dpt = nn.DataParallelTable(1, true, true)
  16. :add(model, gpus)
  17. :threads(function()
  18. require 'pl'
  19. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  20. package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
  21. require 'torch'
  22. require 'cunn'
  23. require 'w2nn'
  24. local cudnn = require 'cudnn'
  25. cudnn.fastest, cudnn.benchmark = fastest, benchmark
  26. end)
  27. dpt.gradInput = nil
  28. model = dpt:cuda()
  29. else
  30. local dpt = nn.DataParallelTable(1, true, true)
  31. :add(model, gpus)
  32. :threads(function()
  33. require 'pl'
  34. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  35. package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
  36. require 'torch'
  37. require 'cunn'
  38. require 'w2nn'
  39. end)
  40. dpt.gradInput = nil
  41. model = dpt:cuda()
  42. end
  43. return model
  44. end
  45. if w2nn then
  46. return w2nn
  47. else
  48. w2nn = {}
  49. local state, ret = pcall(load_cunn)
  50. if not state then
  51. error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. ret)
  52. end
  53. pcall(load_cudnn)
  54. function w2nn.load_model(model_path, force_cudnn)
  55. local model = torch.load(model_path, "ascii")
  56. if force_cudnn then
  57. model = cudnn.convert(model, cudnn)
  58. end
  59. model:cuda():evaluate()
  60. return model
  61. end
  62. function w2nn.data_parallel(model, gpus)
  63. if #gpus > 1 then
  64. return make_data_parallel_table(model, gpus)
  65. else
  66. return model
  67. end
  68. end
  69. require 'LeakyReLU'
  70. require 'ClippedWeightedHuberCriterion'
  71. require 'ClippedMSECriterion'
  72. require 'SSIMCriterion'
  73. require 'InplaceClip01'
  74. require 'L1Criterion'
  75. return w2nn
  76. end