RandomBinaryConvolution.lua 1.1 KB

12345678910111213141516171819202122232425262728
  -- RandomBinaryConvolution.lua
  -- from https://github.com/juefeix/lbcnn.torch
  -- NOTE(review): THNN is not referenced in the visible code; kept in case
  -- the rest of the file (or load-order side effects) rely on it.
  local THNN = require 'nn.THNN'

  --- Spatial convolution whose kernel is a fixed, sparse, random binary
  -- (+1 / 0 / -1) tensor that is never trained.
  -- Registered as 'w2nn.RandomBinaryConvolution', a subclass of
  -- nn.SpatialConvolution; `parent` is the superclass table used to chain
  -- into inherited methods.
  local RandomBinaryConvolution, parent =
     torch.class('w2nn.RandomBinaryConvolution', 'nn.SpatialConvolution')
  --- Construct a stride-1, unpadded convolution with a random binary kernel.
  -- @param nInputPlane number of input planes
  -- @param nOutputPlane number of output planes
  -- @param kW kernel width
  -- @param kH kernel height
  -- @param kSparsity fraction of kernel entries reset() assigns a random
  --   +1/-1 value to (any falsy value selects 0.9)
  function RandomBinaryConvolution:__init(nInputPlane, nOutputPlane, kW, kH, kSparsity)
     -- NOTE: assigned before parent.__init, which (in nn.SpatialConvolution)
     -- is expected to call self:reset(); reset() reads self.kSparsity.
     local sparsity = kSparsity or 0.9
     self.kSparsity = sparsity

     local strideW, strideH = 1, 1
     local padW, padH = 0, 0
     parent.__init(self, nInputPlane, nOutputPlane, kW, kH, strideW, strideH, padW, padH)

     -- Draw the kernel (again) explicitly; kept so the sequence of random
     -- draws matches the original constructor.
     self:reset()
  end
  --- Re-draw the kernel as a fixed sparse random binary tensor.
  --
  -- floor(kSparsity * numElements) distinct kernel positions (clamped to
  -- numElements) are chosen uniformly at random. Each chosen position is
  -- set to +1 or -1 with equal probability; every other entry is 0.
  -- The weight uses nn.SpatialConvolution's layout,
  -- nOutputPlane x nInputPlane x kH x kW, so it matches gradWeight.
  -- The bias is removed, and gradWeight (if present) is zeroed with the
  -- same shape. This module's accGradParameters/updateParameters are
  -- no-ops, so the kernel stays fixed.
  function RandomBinaryConvolution:reset()
     local numElements = self.nInputPlane * self.nOutputPlane * self.kW * self.kH
     -- Guard against kSparsity > 1, which would index past the permutation.
     local numActive = math.min(math.floor(self.kSparsity * numElements), numElements)

     -- Fresh flat buffer of the same tensor type as the current weight.
     local flat = self.weight.new():resize(numElements):zero()

     -- Sample positions WITHOUT replacement: drawing with replacement
     -- produces duplicates, so far fewer than kSparsity * numElements
     -- entries would end up non-zero.
     local positions = torch.randperm(numElements)
     for i = 1, numActive do
        flat[positions[i]] = torch.bernoulli(0.5) * 2 - 1
     end

     -- kH before kW: the layout nn.SpatialConvolution allocates for
     -- weight/gradWeight (the previous kW x kH order broke non-square kernels).
     self.weight = flat:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)

     self.bias = nil
     self.gradBias = nil
     if self.gradWeight then
        self.gradWeight:resizeAs(self.weight):zero()
     end
  end
  --- Intentionally a no-op: the random binary kernel is fixed, so no
  -- gradient is ever accumulated into gradWeight.
  -- @param input ignored
  -- @param gradOutput ignored
  -- @param scale ignored
  function RandomBinaryConvolution:accGradParameters(input, gradOutput, scale)
     return
  end
  --- Intentionally a no-op: overrides the inherited parameter step so the
  -- random binary kernel is never modified by training.
  -- @param learningRate ignored
  function RandomBinaryConvolution:updateParameters(learningRate)
     return
  end