|
@@ -3,6 +3,20 @@ require 'w2nn'
|
|
|
-- ref: http://arxiv.org/abs/1502.01852
|
|
|
-- ref: http://arxiv.org/abs/1501.00092
|
|
|
local srcnn = {}
|
|
|
+
|
|
|
+function nn.SpatialConvolutionMM:reset(stdv)
|
|
|
+ stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
|
|
|
+ self.weight:normal(0, stdv)
|
|
|
+ self.bias:zero()
|
|
|
+end
|
|
|
+if cudnn then
|
|
|
+ function cudnn.SpatialConvolution:reset(stdv)
|
|
|
+ stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
|
|
|
+ self.weight:normal(0, stdv)
|
|
|
+ self.bias:zero()
|
|
|
+ end
|
|
|
+end
|
|
|
+
|
|
|
function srcnn.channels(model)
|
|
|
return model:get(model:size() - 1).weight:size(1)
|
|
|
end
|