-- AuxiliaryLossTable.lua
-- Identity-like container module: forwards all inputs during training
-- (so auxiliary losses can attach to each), and a single selected
-- output during evaluation.
  1. require 'nn'
  2. local AuxiliaryLossTable, parent = torch.class('w2nn.AuxiliaryLossTable', 'nn.Module')
  3. function AuxiliaryLossTable:__init(i)
  4. parent.__init(self)
  5. self.i = i or 1
  6. self.gradInput = {}
  7. self.output_table = {}
  8. self.output_tensor = torch.Tensor()
  9. end
  10. function AuxiliaryLossTable:updateOutput(input)
  11. if self.train then
  12. for i=1,#input do
  13. self.output_table[i] = self.output_table[i] or input[1].new()
  14. self.output_table[i]:resizeAs(input[i]):copy(input[i])
  15. end
  16. for i=#input+1, #self.output_table do
  17. self.output_table[i] = nil
  18. end
  19. self.output = self.output_table
  20. else
  21. self.output_tensor:resizeAs(input[1]):copy(input[1])
  22. self.output_tensor:copy(input[self.i])
  23. self.output = self.output_tensor
  24. end
  25. return self.output
  26. end
  27. function AuxiliaryLossTable:updateGradInput(input, gradOutput)
  28. for i=1,#input do
  29. self.gradInput[i] = self.gradInput[i] or input[1].new()
  30. self.gradInput[i]:resizeAs(input[i]):copy(gradOutput[i])
  31. end
  32. for i=#input+1, #self.gradInput do
  33. self.gradInput[i] = nil
  34. end
  35. return self.gradInput
  36. end
  37. function AuxiliaryLossTable:clearState()
  38. self.gradInput = {}
  39. self.output_table = {}
  40. self.output_tensor:set()
  41. return parent:clearState()
  42. end