|
@@ -9,8 +9,8 @@ require 'cudnn'
|
|
require 'w2nn'
|
|
require 'w2nn'
|
|
local srcnn = require 'srcnn'
|
|
local srcnn = require 'srcnn'
|
|
|
|
|
|
-local function cudnn2cunn(cunn_model)
|
|
|
|
- local cudnn_model = srcnn.waifu2x_cunn(srcnn.channels(cunn_model))
|
|
|
|
|
|
+local function cudnn2cunn(cudnn_model)
|
|
|
|
+ local cunn_model = srcnn.waifu2x_cunn(srcnn.channels(cudnn_model))
|
|
local weight_from = cudnn_model:findModules("cudnn.SpatialConvolution")
|
|
local weight_from = cudnn_model:findModules("cudnn.SpatialConvolution")
|
|
local weight_to = cunn_model:findModules("nn.SpatialConvolutionMM")
|
|
local weight_to = cunn_model:findModules("nn.SpatialConvolutionMM")
|
|
|
|
|