123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
-- Requires the `torch` and `nn` packages, in that order.
-- NOTE(review): no caller of this helper is visible in this file.
local function load_nn()
   for _, package_name in ipairs({ 'torch', 'nn' }) do
      require(package_name)
   end
end
-- Requires the CUDA packages `cutorch` and `cunn`, in that order.
-- Any error raised by `require` propagates to the caller.
local function load_cunn()
   for _, package_name in ipairs({ 'cutorch', 'cunn' }) do
      require(package_name)
   end
end
-- Requires `cudnn` and stores the returned module in the global `cudnn`.
-- The global is intentional: other code in this file tests and uses `cudnn`
-- as a global to decide whether cuDNN is available.
local function load_cudnn()
   local loaded_module = require 'cudnn'
   cudnn = loaded_module
end
-- Wraps `model` in an nn.DataParallelTable and moves the wrapper to CUDA.
--
-- model: value passed to DataParallelTable:add as the module.
-- gpus:  value passed to DataParallelTable:add as the GPU list.
-- Returns whatever :cuda() returns for the wrapper.
local function make_data_parallel_table(model, gpus)
   -- Initialiser handed to :threads(). Two variants are kept so that the
   -- cudnn flags are only touched when the global `cudnn` is set.
   local init_worker
   if cudnn then
      -- Snapshot the current flags so the initialiser can re-apply them.
      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
      init_worker = function()
         require 'pl'
         -- Source path of this code with the leading "@" removed.
         local script_path = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
         -- `path` is presumably the global provided by `require 'pl'` -- TODO confirm.
         package.path = path.join(path.dirname(script_path), "..", "lib", "?.lua;") .. package.path
         require 'torch'
         require 'cunn'
         require 'w2nn'
         local cudnn = require 'cudnn'
         cudnn.fastest, cudnn.benchmark = fastest, benchmark
      end
   else
      init_worker = function()
         require 'pl'
         -- Source path of this code with the leading "@" removed.
         local script_path = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
         -- `path` is presumably the global provided by `require 'pl'` -- TODO confirm.
         package.path = path.join(path.dirname(script_path), "..", "lib", "?.lua;") .. package.path
         require 'torch'
         require 'cunn'
         require 'w2nn'
      end
   end

   -- Same construction chain for both variants; each step's return value is
   -- carried forward exactly as the chained call form would.
   local parallel = nn.DataParallelTable(1, true, true)
   parallel = parallel:add(model, gpus)
   parallel = parallel:threads(init_worker)
   parallel.gradInput = nil
   return parallel:cuda()
end
-- Module body. `w2nn` is kept as a global so that a second load of this file
-- returns the table that already exists instead of re-initialising.
if w2nn then
   return w2nn
else
   w2nn = {}

   -- CUDA support is mandatory: abort with a readable message if it is missing.
   local state, ret = pcall(load_cunn)
   if not state then
      -- tostring() keeps the concatenation safe if the error value is not a string.
      error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. tostring(ret))
   end

   -- cudnn is optional: a load failure is ignored and the global `cudnn`
   -- simply stays unset.
   pcall(load_cudnn)

   -- Loads a model saved in Torch "ascii" format, moves it to CUDA and
   -- switches it to evaluation mode.
   --
   -- model_path:  path passed to torch.load.
   -- force_cudnn: when truthy, the model is passed through cudnn.convert.
   -- Returns the loaded model.
   -- Raises an error if force_cudnn is truthy but cudnn is not available.
   function w2nn.load_model(model_path, force_cudnn)
      -- Fail early with a clear message instead of indexing a nil `cudnn`
      -- after the model file has already been read.
      if force_cudnn and not cudnn then
         error("force_cudnn was requested but the cudnn module could not be loaded", 2)
      end
      local model = torch.load(model_path, "ascii")
      if force_cudnn then
         model = cudnn.convert(model, cudnn)
      end
      model:cuda():evaluate()
      return model
   end

   -- Returns `model` wrapped for multi-GPU use when more than one entry is
   -- given in `gpus`; otherwise returns `model` unchanged.
   function w2nn.data_parallel(model, gpus)
      if #gpus > 1 then
         return make_data_parallel_table(model, gpus)
      else
         return model
      end
   end

   -- Project-local modules; presumably each one registers a custom layer or
   -- criterion -- TODO confirm against the files in lib/.
   require 'LeakyReLU'
   require 'ClippedWeightedHuberCriterion'
   require 'ClippedMSECriterion'
   require 'SSIMCriterion'
   require 'InplaceClip01'
   require 'L1Criterion'

   return w2nn
end
|