|
@@ -5,14 +5,18 @@ function AuxiliaryLossCriterion:__init(base_criterion, args)
|
|
|
parent.__init(self)
|
|
|
self.base_criterion = base_criterion
|
|
|
self.args = args
|
|
|
- self.criterions = {}
|
|
|
self.gradInput = {}
|
|
|
self.sizeAverage = false
|
|
|
+ self.criterions = {}
|
|
|
+ if self.base_criterion.has_instance_loss then
|
|
|
+ self.instance_loss = {}
|
|
|
+ end
|
|
|
end
|
|
|
function AuxiliaryLossCriterion:updateOutput(input, target)
|
|
|
local sum_output = 0
|
|
|
if type(input) == "table" then
|
|
|
-- model:training()
|
|
|
+ self.output = 0
|
|
|
for i = 1, #input do
|
|
|
if self.criterions[i] == nil then
|
|
|
if self.args ~= nil then
|
|
@@ -25,10 +29,22 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
|
|
self.criterions[i]:cuda()
|
|
|
end
|
|
|
end
|
|
|
- local output = self.criterions[i]:updateOutput(input[i], target)
|
|
|
- sum_output = sum_output + output
|
|
|
+ self.output = self.output + self.criterions[i]:updateOutput(input[i], target) / #input
|
|
|
+
|
|
|
+ if self.instance_loss then
|
|
|
+ local batch_size = #self.criterions[i].instance_loss
|
|
|
+ local scale = 1.0 / #input
|
|
|
+ if i == 1 then
|
|
|
+ for j = 1, batch_size do
|
|
|
+ self.instance_loss[j] = self.criterions[i].instance_loss[j] * scale
|
|
|
+ end
|
|
|
+ else
|
|
|
+ for j = 1, batch_size do
|
|
|
+ self.instance_loss[j] = self.instance_loss[j] + self.criterions[i].instance_loss[j] * scale
|
|
|
+ end
|
|
|
+ end
|
|
|
+ end
|
|
|
end
|
|
|
- self.output = sum_output / #input
|
|
|
else
|
|
|
-- model:evaluate()
|
|
|
if self.criterions[1] == nil then
|
|
@@ -43,6 +59,12 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
|
|
end
|
|
|
end
|
|
|
self.output = self.criterions[1]:updateOutput(input, target)
|
|
|
+ if self.instance_loss then
|
|
|
+ local batch_size = #self.criterions[1].instance_loss
|
|
|
+ for j = 1, batch_size do
|
|
|
+ self.instance_loss[j] = self.criterions[1].instance_loss[j]
|
|
|
+ end
|
|
|
+ end
|
|
|
end
|
|
|
return self.output
|
|
|
end
|