Ver código fonte

Add support for identity initializer for dilated convolution, and refactor

nagadomi 8 anos atrás
pai
commit
05bc54fa12
1 arquivos alterados com 67 adições e 20 exclusões
  1. 67 20
      lib/srcnn.lua

+ 67 - 20
lib/srcnn.lua

@@ -4,34 +4,52 @@ require 'w2nn'
 -- ref: http://arxiv.org/abs/1501.00092
 local srcnn = {}
 
-function nn.SpatialConvolutionMM:reset(stdv)
-   local fin = self.kW * self.kH * self.nInputPlane
-   local fout = self.kW * self.kH * self.nOutputPlane
+local function msra_filler(mod)
+   local fin = mod.kW * mod.kH * mod.nInputPlane
+   local fout = mod.kW * mod.kH * mod.nOutputPlane
    stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-   self.weight:normal(0, stdv)
-   self.bias:zero()
+   mod.weight:normal(0, stdv)
+   mod.bias:zero()
+end
+local function identity_filler(mod)
+   assert(mod.nInputPlane <= mod.nOutputPlane)
+   mod.weight:normal(0, 0.01)
+   mod.bias:zero()
+   local num_groups = mod.nInputPlane -- fixed
+   local filler_value = num_groups / mod.nOutputPlane
+   local in_group_size = math.floor(mod.nInputPlane / num_groups)
+   local out_group_size = math.floor(mod.nOutputPlane / num_groups)
+   local x = math.floor(mod.kW / 2)
+   local y = math.floor(mod.kH / 2)
+   for i = 0, num_groups - 1 do
+      for j = i * out_group_size, (i + 1) * out_group_size - 1 do
+	 for k = i * in_group_size, (i + 1) * in_group_size - 1 do
+	    mod.weight[j+1][k+1][y+1][x+1] = filler_value
+	 end
+      end
+   end
+end
+function nn.SpatialConvolutionMM:reset(stdv)
+   msra_filler(self)
 end
 function nn.SpatialFullConvolution:reset(stdv)
-   local fin = self.kW * self.kH * self.nInputPlane
-   local fout = self.kW * self.kH * self.nOutputPlane
-   stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-   self.weight:normal(0, stdv)
-   self.bias:zero()
+   msra_filler(self)
 end
+function nn.SpatialDilatedConvolution:reset(stdv)
+   identity_filler(self)
+end
+
 if cudnn and cudnn.SpatialConvolution then
    function cudnn.SpatialConvolution:reset(stdv)
-      local fin = self.kW * self.kH * self.nInputPlane
-      local fout = self.kW * self.kH * self.nOutputPlane
-      stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-      self.weight:normal(0, stdv)
-      self.bias:zero()
+      msra_filler(self)
    end
    function cudnn.SpatialFullConvolution:reset(stdv)
-      local fin = self.kW * self.kH * self.nInputPlane
-      local fout = self.kW * self.kH * self.nOutputPlane
-      stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
-      self.weight:normal(0, stdv)
-      self.bias:zero()
+      msra_filler(self)
+   end
+   if cudnn.SpatialDilatedConvolution then
+      function cudnn.SpatialDilatedConvolution:reset(stdv)
+	 identity_filler(self)
+      end
    end
 end
 function nn.SpatialConvolutionMM:clearState()
@@ -162,6 +180,34 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
 end
 srcnn.SpatialMaxPooling = SpatialMaxPooling
 
+local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
+   if backend == "cunn" then
+      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
+   elseif backend == "cudnn" then
+      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+srcnn.SpatialAveragePooling = SpatialAveragePooling
+
+local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)      
+   if backend == "cunn" then
+      return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+   elseif backend == "cudnn" then
+      if cudnn.SpatialDilatedConvolution then
+	 -- cudnn v 6
+	 return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+      else
+	 return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
+      end
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
+
+
 -- VGG style net(7 layers)
 function srcnn.vgg_7(backend, ch)
    local model = nn.Sequential()
@@ -555,6 +601,7 @@ function srcnn.create(model_name, backend, color)
       error("unsupported model_name: " .. model_name)
    end
 end
+
 --[[
 local model = srcnn.fcn_v1("cunn", 3):cuda()
 print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())