Print.lua 416 B

1234567891011121314151617
  1. local Print, parent = torch.class('w2nn.Print','nn.Module')
  2. function Print:__init(id)
  3. parent.__init(self)
  4. self.id = id
  5. end
  6. function Print:updateOutput(input)
  7. print("----", self.id)
  8. print(input:size())
  9. self.output:resizeAs(input)
  10. self.output:copy(input)
  11. return self.output
  12. end
  13. function Print:updateGradInput(input, gradOutput)
  14. self.gradInput:resizeAs(GradOutput)
  15. return self.gradInput
  16. end