-- LeakyReLU.lua

if w2nn and w2nn.LeakyReLU then
   return w2nn.LeakyReLU
end
local LeakyReLU, parent = torch.class('w2nn.LeakyReLU', 'nn.Module')

function LeakyReLU:__init(negative_scale)
   parent.__init(self)
   -- slope applied to negative inputs (default: 1/3)
   self.negative_scale = negative_scale or 0.333
   self.negative = torch.Tensor()
end

function LeakyReLU:updateOutput(input)
   -- positive part: max(0, x) = (|x| + x) / 2
   self.output:resizeAs(input):copy(input):abs():add(input):div(2)
   -- negative part: negative_scale * min(0, x) = -negative_scale * (|x| - x) / 2
   self.negative:resizeAs(input):copy(input):abs():add(-1.0, input):mul(-0.5 * self.negative_scale)
   self.output:add(self.negative)
   return self.output
end

function LeakyReLU:updateGradInput(input, gradOutput)
   self.gradInput:resizeAs(gradOutput)
   -- self.negative still holds negative_scale * min(0, x) from updateOutput,
   -- so sign() + 1 yields a {0, 1} mask of the non-negative inputs
   self.negative:sign():add(1)
   -- filter positive: pass the gradient through where input >= 0
   torch.cmul(self.gradInput, gradOutput, self.negative)
   -- filter negative: scale the gradient by negative_scale where input < 0
   self.negative:add(-1):mul(-1 * self.negative_scale):cmul(gradOutput)
   self.gradInput:add(self.negative)
   return self.gradInput
end
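
--[[ Usage sketch (illustrative addition, not part of the original module):
     assumes Torch7 with the 'nn' package, and that this file has been loaded
     (e.g., through the project's w2nn loader) so the class is registered.

   local act = w2nn.LeakyReLU(0.2)              -- negative slope 0.2
   local x = torch.Tensor({-1.0, 0.5})
   print(act:forward(x))                        -- -0.2,  0.5
   print(act:backward(x, torch.ones(2)))        --  0.2,  1.0

   -- optional finite-difference gradient check using nn's test utilities
   local err = nn.Jacobian.testJacobian(w2nn.LeakyReLU(0.333), torch.randn(3, 3))
   print('jacobian error:', err)
--]]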