-- AuxiliaryLossCriterion.lua
  1. require 'nn'
  2. local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
  3. function AuxiliaryLossCriterion:__init(base_criterion)
  4. parent.__init(self)
  5. self.base_criterion = base_criterion
  6. self.criterions = {}
  7. self.gradInput = {}
  8. self.sizeAverage = false
  9. end
  10. function AuxiliaryLossCriterion:updateOutput(input, target)
  11. local sum_output = 0
  12. if type(input) == "table" then
  13. -- model:training()
  14. for i = 1, #input do
  15. if self.criterions[i] == nil then
  16. self.criterions[i] = self.base_criterion()
  17. self.criterions[i].sizeAverage = self.sizeAverage
  18. if input[i]:type() == "torch.CudaTensor" then
  19. self.criterions[i]:cuda()
  20. end
  21. end
  22. local output = self.criterions[i]:updateOutput(input[i], target)
  23. sum_output = sum_output + output
  24. end
  25. self.output = sum_output / #input
  26. else
  27. -- model:evaluate()
  28. if self.criterions[1] == nil then
  29. self.criterions[1] = self.base_criterion()
  30. self.criterions[1].sizeAverage = self.sizeAverage()
  31. if input:type() == "torch.CudaTensor" then
  32. self.criterions[1]:cuda()
  33. end
  34. end
  35. self.output = self.criterions[1]:updateOutput(input, target)
  36. end
  37. return self.output
  38. end
  39. function AuxiliaryLossCriterion:updateGradInput(input, target)
  40. for i=1,#input do
  41. local gradInput = self.criterions[i]:updateGradInput(input[i], target)
  42. self.gradInput[i] = self.gradInput[i] or gradInput.new()
  43. self.gradInput[i]:resizeAs(gradInput):copy(gradInput)
  44. end
  45. for i=#input+1, #self.gradInput do
  46. self.gradInput[i] = nil
  47. end
  48. return self.gradInput
  49. end