|
@@ -17,6 +17,16 @@ if cudnn and cudnn.SpatialConvolution then
|
|
|
end
|
|
|
end
|
|
|
|
|
|
+function nn.SpatialConvolutionMM:clearState()
|
|
|
+ if self.gradWeight then
|
|
|
+ self.gradWeight = torch.Tensor(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):typeAs(self.gradWeight):zero()
|
|
|
+ end
|
|
|
+ if self.gradBias then
|
|
|
+ self.gradBias = torch.Tensor(self.nOutputPlane):typeAs(self.gradBias):zero()
|
|
|
+ end
|
|
|
+ return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
|
|
|
+end
|
|
|
+
|
|
|
function srcnn.channels(model)
|
|
|
return model:get(model:size() - 1).weight:size(1)
|
|
|
end
|