EdgeFilter.lua 764 B

12345678910111213141516171819202122232425262728293031
  1. require 'cunn'
  2. local EdgeFilter, parent = torch.class('w2nn.EdgeFilter', 'nn.SpatialConvolution')
  3. function EdgeFilter:__init(nInputPlane)
  4. local output = 0
  5. parent.__init(self, nInputPlane, nInputPlane * 8, 3, 3, 1, 1, 0, 0)
  6. end
  7. function EdgeFilter:reset()
  8. self.bias = nil
  9. self.gradBias = nil
  10. self.gradWeight:fill(0)
  11. self.weight:fill(0)
  12. local fi = 1
  13. -- each channel
  14. for ch = 1, self.nInputPlane do
  15. for i = 0, 8 do
  16. y = math.floor(i / 3) + 1
  17. x = i % 3 + 1
  18. if not (y == 2 and x == 2) then
  19. self.weight[fi][ch][2][2] = 1
  20. self.weight[fi][ch][y][x] = -1
  21. fi = fi + 1
  22. end
  23. end
  24. end
  25. end
  26. function EdgeFilter:accGradParameters(input, gradOutput, scale)
  27. end
  28. function EdgeFilter:updateParameters(learningRate)
  29. end