-- AuxiliaryLossCriterion.lua
  1. require 'nn'
  2. local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
  3. function AuxiliaryLossCriterion:__init(base_criterion, args)
  4. parent.__init(self)
  5. self.base_criterion = base_criterion
  6. self.args = args
  7. self.criterions = {}
  8. self.gradInput = {}
  9. self.sizeAverage = false
  10. end
  11. function AuxiliaryLossCriterion:updateOutput(input, target)
  12. local sum_output = 0
  13. if type(input) == "table" then
  14. -- model:training()
  15. for i = 1, #input do
  16. if self.criterions[i] == nil then
  17. if self.args ~= nil then
  18. self.criterions[i] = self.base_criterion(table.unpack(self.args))
  19. else
  20. self.criterions[i] = self.base_criterion()
  21. end
  22. self.criterions[i].sizeAverage = self.sizeAverage
  23. if input[i]:type() == "torch.CudaTensor" then
  24. self.criterions[i]:cuda()
  25. end
  26. end
  27. local output = self.criterions[i]:updateOutput(input[i], target)
  28. sum_output = sum_output + output
  29. end
  30. self.output = sum_output / #input
  31. else
  32. -- model:evaluate()
  33. if self.criterions[1] == nil then
  34. if self.args ~= nil then
  35. self.criterions[1] = self.base_criterion(table.unpack(self.args))
  36. else
  37. self.criterions[1] = self.base_criterion()
  38. end
  39. self.criterions[1].sizeAverage = self.sizeAverage
  40. if input:type() == "torch.CudaTensor" then
  41. self.criterions[1]:cuda()
  42. end
  43. end
  44. self.output = self.criterions[1]:updateOutput(input, target)
  45. end
  46. return self.output
  47. end
  48. function AuxiliaryLossCriterion:updateGradInput(input, target)
  49. for i=1,#input do
  50. local gradInput = self.criterions[i]:updateGradInput(input[i], target)
  51. self.gradInput[i] = self.gradInput[i] or gradInput.new()
  52. self.gradInput[i]:resizeAs(gradInput):copy(gradInput)
  53. end
  54. for i=#input+1, #self.gradInput do
  55. self.gradInput[i] = nil
  56. end
  57. return self.gradInput
  58. end