-- AuxiliaryLossCriterion.lua (2.5 KB)
  1. require 'nn'
  2. local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
  3. function AuxiliaryLossCriterion:__init(base_criterion, args)
  4. parent.__init(self)
  5. self.base_criterion = base_criterion
  6. self.args = args
  7. self.gradInput = {}
  8. self.sizeAverage = false
  9. self.criterions = {}
  10. if self.base_criterion.has_instance_loss then
  11. self.instance_loss = {}
  12. end
  13. end
  14. function AuxiliaryLossCriterion:updateOutput(input, target)
  15. local sum_output = 0
  16. if type(input) == "table" then
  17. -- model:training()
  18. self.output = 0
  19. for i = 1, #input do
  20. if self.criterions[i] == nil then
  21. if self.args ~= nil then
  22. self.criterions[i] = self.base_criterion(table.unpack(self.args))
  23. else
  24. self.criterions[i] = self.base_criterion()
  25. end
  26. self.criterions[i].sizeAverage = self.sizeAverage
  27. if input[i]:type() == "torch.CudaTensor" then
  28. self.criterions[i]:cuda()
  29. end
  30. end
  31. self.output = self.output + self.criterions[i]:updateOutput(input[i], target) / #input
  32. if self.instance_loss then
  33. local batch_size = #self.criterions[i].instance_loss
  34. local scale = 1.0 / #input
  35. if i == 1 then
  36. for j = 1, batch_size do
  37. self.instance_loss[j] = self.criterions[i].instance_loss[j] * scale
  38. end
  39. else
  40. for j = 1, batch_size do
  41. self.instance_loss[j] = self.instance_loss[j] + self.criterions[i].instance_loss[j] * scale
  42. end
  43. end
  44. end
  45. end
  46. else
  47. -- model:evaluate()
  48. if self.criterions[1] == nil then
  49. if self.args ~= nil then
  50. self.criterions[1] = self.base_criterion(table.unpack(self.args))
  51. else
  52. self.criterions[1] = self.base_criterion()
  53. end
  54. self.criterions[1].sizeAverage = self.sizeAverage
  55. if input:type() == "torch.CudaTensor" then
  56. self.criterions[1]:cuda()
  57. end
  58. end
  59. self.output = self.criterions[1]:updateOutput(input, target)
  60. if self.instance_loss then
  61. local batch_size = #self.criterions[1].instance_loss
  62. for j = 1, batch_size do
  63. self.instance_loss[j] = self.criterions[1].instance_loss[j]
  64. end
  65. end
  66. end
  67. return self.output
  68. end
  69. function AuxiliaryLossCriterion:updateGradInput(input, target)
  70. for i=1,#input do
  71. local gradInput = self.criterions[i]:updateGradInput(input[i], target)
  72. self.gradInput[i] = self.gradInput[i] or gradInput.new()
  73. self.gradInput[i]:resizeAs(gradInput):copy(gradInput)
  74. end
  75. for i=#input+1, #self.gradInput do
  76. self.gradInput[i] = nil
  77. end
  78. return self.gradInput
  79. end