1234567891011121314151617181920212223242526272829303132333435 |
--- Load the core Torch7 packages.
-- Requires 'torch' and then 'nn', in that order. Any error raised by
-- `require` is propagated to the caller unchanged.
local function load_nn()
   for _, name in ipairs({ 'torch', 'nn' }) do
      require(name)
   end
end
--- Load the CUDA tensor library and the CUDA nn backend.
-- Requires 'cutorch' and then 'cunn', in that order. Any error raised by
-- `require` propagates, so callers can wrap this in `pcall` to treat
-- CUDA as optional.
local function load_cunn()
   local modules = { 'cutorch', 'cunn' }
   for i = 1, #modules do
      require(modules[i])
   end
end
--- Load the cuDNN bindings and enable `cudnn.benchmark`.
-- The module returned by `require` is assigned to the global `cudnn`
-- explicitly. This file later reads that global (in `w2nn.load_model`), so it
-- must not depend on the module happening to define the global as a side
-- effect of being required.
-- Any error raised by `require` propagates to the caller.
local function load_cudnn()
   cudnn = require('cudnn')
   -- Per the cudnn.torch README, this enables cuDNN's built-in algorithm
   -- auto-tuner for convolutions.
   cudnn.benchmark = true
end
-- Module body. `w2nn` is kept in a global: if it already exists, the existing
-- table is returned unchanged instead of re-running the setup below.
if w2nn then
   return w2nn
else
   -- torch/nn are needed regardless of CUDA availability (torch.load in
   -- load_model below), so load them explicitly rather than relying on the
   -- CUDA packages to pull them in transitively.
   load_nn()
   -- CUDA backends stay best-effort at module load time. Remember why they
   -- failed so load_model can report the cause instead of crashing later
   -- with an unrelated-looking error.
   local cunn_ok, cunn_err = pcall(load_cunn)
   local cudnn_ok, cudnn_err = pcall(load_cudnn)
   -- Assigned before the requires below; presumably those modules reference
   -- the global `w2nn` while loading (TODO confirm against their sources).
   w2nn = {}

   --- Load a model serialized with `torch.save` in ASCII format and prepare
   -- it for inference on the GPU.
   -- @tparam string model_path path of the serialized model file
   -- @tparam[opt] boolean force_cudnn if truthy, convert the model with
   --   `cudnn.convert(model, cudnn)` before moving it to CUDA
   -- @return the model after `:cuda():evaluate()`
   -- @raise error if cutorch/cunn could not be loaded, or if `force_cudnn`
   --   is set but cudnn could not be loaded
   function w2nn.load_model(model_path, force_cudnn)
      if not cunn_ok then
         error("w2nn.load_model: cutorch/cunn are not available: "
                  .. tostring(cunn_err), 2)
      end
      if force_cudnn and not cudnn_ok then
         error("w2nn.load_model: force_cudnn requested but cudnn is not available: "
                  .. tostring(cudnn_err), 2)
      end
      local model = torch.load(model_path, "ascii")
      if force_cudnn then
         model = cudnn.convert(model, cudnn)
      end
      model:cuda():evaluate()
      return model
   end

   require 'LeakyReLU'
   require 'LeakyReLU_deprecated'
   require 'DepthExpand2x'
   require 'PSNRCriterion'
   require 'ClippedWeightedHuberCriterion'
   require 'ClippedMSECriterion'
   return w2nn
end
|