|
@@ -4,34 +4,52 @@ require 'w2nn'
|
|
|
-- ref: http://arxiv.org/abs/1501.00092
|
|
|
local srcnn = {}
|
|
|
|
|
|
-function nn.SpatialConvolutionMM:reset(stdv)
|
|
|
- local fin = self.kW * self.kH * self.nInputPlane
|
|
|
- local fout = self.kW * self.kH * self.nOutputPlane
|
|
|
+local function msra_filler(mod)
|
|
|
+ local fin = mod.kW * mod.kH * mod.nInputPlane
|
|
|
+ local fout = mod.kW * mod.kH * mod.nOutputPlane
|
|
|
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
|
|
- self.weight:normal(0, stdv)
|
|
|
- self.bias:zero()
|
|
|
+ mod.weight:normal(0, stdv)
|
|
|
+ mod.bias:zero()
|
|
|
+end
|
|
|
+local function identity_filler(mod)
|
|
|
+ assert(mod.nInputPlane <= mod.nOutputPlane)
|
|
|
+ mod.weight:normal(0, 0.01)
|
|
|
+ mod.bias:zero()
|
|
|
+ local num_groups = mod.nInputPlane -- fixed
|
|
|
+ local filler_value = num_groups / mod.nOutputPlane
|
|
|
+ local in_group_size = math.floor(mod.nInputPlane / num_groups)
|
|
|
+ local out_group_size = math.floor(mod.nOutputPlane / num_groups)
|
|
|
+ local x = math.floor(mod.kW / 2)
|
|
|
+ local y = math.floor(mod.kH / 2)
|
|
|
+ for i = 0, num_groups - 1 do
|
|
|
+ for j = i * out_group_size, (i + 1) * out_group_size - 1 do
|
|
|
+ for k = i * in_group_size, (i + 1) * in_group_size - 1 do
|
|
|
+ mod.weight[j+1][k+1][y+1][x+1] = filler_value
|
|
|
+ end
|
|
|
+ end
|
|
|
+ end
|
|
|
+end
|
|
|
+function nn.SpatialConvolutionMM:reset(stdv)
|
|
|
+ msra_filler(self)
|
|
|
end
|
|
|
function nn.SpatialFullConvolution:reset(stdv)
|
|
|
- local fin = self.kW * self.kH * self.nInputPlane
|
|
|
- local fout = self.kW * self.kH * self.nOutputPlane
|
|
|
- stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
|
|
- self.weight:normal(0, stdv)
|
|
|
- self.bias:zero()
|
|
|
+ msra_filler(self)
|
|
|
end
|
|
|
+function nn.SpatialDilatedConvolution:reset(stdv)
|
|
|
+ identity_filler(self)
|
|
|
+end
|
|
|
+
|
|
|
if cudnn and cudnn.SpatialConvolution then
|
|
|
function cudnn.SpatialConvolution:reset(stdv)
|
|
|
- local fin = self.kW * self.kH * self.nInputPlane
|
|
|
- local fout = self.kW * self.kH * self.nOutputPlane
|
|
|
- stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
|
|
- self.weight:normal(0, stdv)
|
|
|
- self.bias:zero()
|
|
|
+ msra_filler(self)
|
|
|
end
|
|
|
function cudnn.SpatialFullConvolution:reset(stdv)
|
|
|
- local fin = self.kW * self.kH * self.nInputPlane
|
|
|
- local fout = self.kW * self.kH * self.nOutputPlane
|
|
|
- stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
|
|
- self.weight:normal(0, stdv)
|
|
|
- self.bias:zero()
|
|
|
+ msra_filler(self)
|
|
|
+ end
|
|
|
+ if cudnn.SpatialDilatedConvolution then
|
|
|
+ function cudnn.SpatialDilatedConvolution:reset(stdv)
|
|
|
+ identity_filler(self)
|
|
|
+ end
|
|
|
end
|
|
|
end
|
|
|
function nn.SpatialConvolutionMM:clearState()
|
|
@@ -162,6 +180,34 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
|
|
end
|
|
|
srcnn.SpatialMaxPooling = SpatialMaxPooling
|
|
|
|
|
|
+local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
|
|
|
+ if backend == "cunn" then
|
|
|
+ return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
|
|
|
+ elseif backend == "cudnn" then
|
|
|
+ return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
|
|
|
+ else
|
|
|
+ error("unsupported backend:" .. backend)
|
|
|
+ end
|
|
|
+end
|
|
|
+srcnn.SpatialAveragePooling = SpatialAveragePooling
|
|
|
+
|
|
|
+local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
|
|
|
+ if backend == "cunn" then
|
|
|
+ return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
|
|
|
+ elseif backend == "cudnn" then
|
|
|
+ if cudnn.SpatialDilatedConvolution then
|
|
|
+ -- cudnn v 6
|
|
|
+ return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
|
|
|
+ else
|
|
|
+ return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
|
|
|
+ end
|
|
|
+ else
|
|
|
+ error("unsupported backend:" .. backend)
|
|
|
+ end
|
|
|
+end
|
|
|
+srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
|
|
|
+
|
|
|
+
|
|
|
-- VGG style net(7 layers)
|
|
|
function srcnn.vgg_7(backend, ch)
|
|
|
local model = nn.Sequential()
|
|
@@ -555,6 +601,7 @@ function srcnn.create(model_name, backend, color)
|
|
|
error("unsupported model_name: " .. model_name)
|
|
|
end
|
|
|
end
|
|
|
+
|
|
|
--[[
|
|
|
local model = srcnn.fcn_v1("cunn", 3):cuda()
|
|
|
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
|