-- GradWeight.lua (576 B)

  -- GradWeight: an nn.Module that is the identity in the forward pass and a
  -- constant scaling in the backward pass.
  --   forward : output    = copy of input
  --   backward: gradInput = gradOutput * constant_scalar
  -- NOTE(review): presumably used to re-weight the gradient flowing into one
  -- branch of a network without altering its activations -- confirm with callers.
  local GradWeight, parent = torch.class('w2nn.GradWeight', 'nn.Module')

  -- Constructor.
  -- @param constant_scalar  the number that gradOutput is multiplied by in
  --                         updateGradInput; a non-number argument raises an
  --                         assertion error.
  function GradWeight:__init(constant_scalar)
     parent.__init(self)
     local is_number = type(constant_scalar) == 'number'
     assert(is_number, 'input is not scalar!')
     self.constant_scalar = constant_scalar
  end

  -- Forward pass: resize self.output to match `input`, then copy the values
  -- across unchanged. The output is this module's own buffer, not an alias
  -- of `input`.
  -- @return self.output
  function GradWeight:updateOutput(input)
     -- Torch in-place Tensor methods return the receiver, so the calls chain.
     self.output:resizeAs(input):copy(input)
     return self.output
  end

  -- Backward pass: gradInput becomes gradOutput multiplied element-wise by
  -- the constant chosen at construction. `input` is not read here.
  -- @return self.gradInput
  function GradWeight:updateGradInput(input, gradOutput)
     local scale = self.constant_scalar
     self.gradInput:resizeAs(gradOutput):copy(gradOutput):mul(scale)
     return self.gradInput
  end