EdgeFilter.lua 829 B

123456789101112131415161718192021222324252627282930313233
  1. -- EdgeFilter.lua
  2. -- from https://github.com/juefeix/lbcnn.torch
  3. require 'cunn'
  -- w2nn.EdgeFilter: a nn.SpatialConvolution subclass whose 3x3 kernels are
  -- set by EdgeFilter:reset() instead of being randomly initialised.
  local EdgeFilter, parent = torch.class('w2nn.EdgeFilter', 'nn.SpatialConvolution')

  -- Constructor.
  --
  -- nInputPlane: number of input channels.
  --
  -- The layer is configured with nInputPlane * 8 output planes, a 3x3
  -- kernel, stride 1 in both directions and no padding.
  function EdgeFilter:__init(nInputPlane)
    parent.__init(self, nInputPlane, nInputPlane * 8, 3, 3, 1, 1, 0, 0)
  end
  -- Replaces the layer parameters with a fixed bank of 3x3 difference filters.
  --
  -- The bias and its gradient are dropped, and both the weight and gradWeight
  -- tensors are zeroed. Then, for every input channel, eight filters are
  -- written: each has +1 at the kernel centre (2, 2) and -1 at one of the
  -- eight surrounding positions. Filter `fi` only has non-zero entries for
  -- its own input channel `ch`; all other channel slices stay zero.
  --
  -- The weight tensor is indexed as weight[filter][channel][row][col].
  function EdgeFilter:reset()
    self.bias = nil
    self.gradBias = nil
    self.gradWeight:fill(0)
    self.weight:fill(0)

    -- Index of the next output filter to fill; runs 1 .. nInputPlane * 8.
    local fi = 1
    -- each channel
    for ch = 1, self.nInputPlane do
      -- Walk the nine cells of the 3x3 kernel in row-major order.
      for i = 0, 8 do
        -- Declared local so reset() does not leak `x`/`y` into the global table.
        local y = math.floor(i / 3) + 1
        local x = i % 3 + 1
        -- The centre cell is the reference point, not a neighbour: skip it.
        if not (y == 2 and x == 2) then
          self.weight[fi][ch][2][2] = 1
          self.weight[fi][ch][y][x] = -1
          fi = fi + 1
        end
      end
    end
  end
  -- Deliberate no-op override: the body is empty, so calling it does not
  -- touch gradWeight (or any other field) of this module.
  -- NOTE(review): presumably this keeps the filters written by reset() fixed
  -- during training - confirm against the callers.
  --
  -- input, gradOutput, scale: accepted for interface compatibility and ignored.
  function EdgeFilter:accGradParameters(input, gradOutput, scale)
    -- intentionally empty
  end
  -- Deliberate no-op override: the body is empty, so calling it leaves the
  -- weight tensor exactly as it was.
  --
  -- learningRate: accepted for interface compatibility and ignored.
  function EdgeFilter:updateParameters(learningRate)
    -- intentionally empty
  end