Selaa lähdekoodia

More clearState for nn.SpatialConvolutionMM

nagadomi 9 vuotta sitten
vanhempi
commit
4a1629d046
1 muutettua tiedostoa jossa 10 lisäystä ja 0 poistoa
  1. 10 0
      lib/srcnn.lua

+ 10 - 0
lib/srcnn.lua

@@ -17,6 +17,16 @@ if cudnn and cudnn.SpatialConvolution then
    end
 end
 
+function nn.SpatialConvolutionMM:clearState()
+   if self.gradWeight then
+      self.gradWeight = torch.Tensor(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):typeAs(self.gradWeight):zero()
+   end
+   if self.gradBias then
+      self.gradBias = torch.Tensor(self.nOutputPlane):typeAs(self.gradBias):zero()
+   end
+   return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
+end
+
 function srcnn.channels(model)
    return model:get(model:size() - 1).weight:size(1)
 end