12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- require 'nn'
-- Criterion class registered as 'w2nn.AuxiliaryLossCriterion' with
-- 'nn.Criterion' as its parent; `parent` is kept so the constructor can
-- call the base-class initializer.
local AuxiliaryLossCriterion, parent =
   torch.class('w2nn.AuxiliaryLossCriterion', 'nn.Criterion')
--- Construct a criterion that applies `base_criterion` to one or more outputs.
-- Child criterions are not built here; `updateOutput` creates them lazily,
-- one per output, by calling `base_criterion(...)`.
-- @param base_criterion criterion class/constructor; called with the
--        unpacked `args` (or with no arguments when `args` is nil)
-- @param args optional table of constructor arguments for each child
function AuxiliaryLossCriterion:__init(base_criterion, args)
   -- Must run first: it initializes the base fields that are overridden below.
   parent.__init(self)
   self.criterions = {}
   self.base_criterion = base_criterion
   self.args = args
   self.sizeAverage = false
   -- Table-valued: one gradient tensor per output (see updateGradInput).
   self.gradInput = {}
   -- Per-sample losses are only tracked when the base class advertises them.
   self.instance_loss = base_criterion.has_instance_loss and {} or nil
end
-- Lua 5.1/LuaJIT provide `unpack`; Lua 5.2+ (or LuaJIT in 5.2-compat mode)
-- provide `table.unpack`. Use whichever exists.
local unpack_args = table.unpack or unpack

--- Return the child criterion for output slot `i`, creating it on first use.
-- A new child is built from `self.base_criterion` (with `self.args` unpacked
-- when present), inherits `self.sizeAverage`, and is converted with `:cuda()`
-- when `tensor` is a torch.CudaTensor. Later calls return the cached child.
local function ensure_criterion(self, i, tensor)
   local criterion = self.criterions[i]
   if criterion == nil then
      if self.args ~= nil then
         criterion = self.base_criterion(unpack_args(self.args))
      else
         criterion = self.base_criterion()
      end
      self.criterions[i] = criterion
      criterion.sizeAverage = self.sizeAverage
      if tensor:type() == "torch.CudaTensor" then
         criterion:cuda()
      end
   end
   return criterion
end

--- Remove entries n+1..#t so that sequence `t` holds exactly n elements.
-- Trims in place (keeps the table's identity for anyone holding a reference).
local function truncate_sequence(t, n)
   for j = #t, n + 1, -1 do
      t[j] = nil
   end
end

--- Compute the loss for a single output tensor or a table of output tensors.
-- Table input (the `model:training()` case in the original comments): the
-- loss is the mean of each child criterion's loss against the same `target`.
-- Tensor input (the `model:evaluate()` case): the loss of child 1 alone.
-- When per-sample losses are tracked (`self.instance_loss` is a table), it is
-- refilled with the (averaged) child `instance_loss` values and trimmed to the
-- current batch size, so no entries from a larger previous batch survive.
-- NOTE(review): child `instance_loss` is assumed to be a Lua sequence of
-- numbers (it is measured with `#` and indexed per sample) — confirm.
-- @param input tensor, or table of tensors
-- @param target target passed unchanged to every child criterion
-- @return self.output, the (averaged) loss
function AuxiliaryLossCriterion:updateOutput(input, target)
   if type(input) == "table" then
      -- model:training()
      local n = #input
      local scale = 1.0 / n
      local batch_size = 0
      self.output = 0
      for i = 1, n do
         local criterion = ensure_criterion(self, i, input[i])
         self.output = self.output + criterion:updateOutput(input[i], target) / n
         if self.instance_loss then
            local child_loss = criterion.instance_loss
            batch_size = #child_loss
            for j = 1, batch_size do
               local weighted = child_loss[j] * scale
               if i == 1 then
                  -- First output overwrites values left from the previous call.
                  self.instance_loss[j] = weighted
               else
                  self.instance_loss[j] = self.instance_loss[j] + weighted
               end
            end
         end
      end
      if self.instance_loss then
         -- Drop stale per-sample entries when this batch is smaller than the last.
         truncate_sequence(self.instance_loss, batch_size)
      end
   else
      -- model:evaluate()
      local criterion = ensure_criterion(self, 1, input)
      self.output = criterion:updateOutput(input, target)
      if self.instance_loss then
         local child_loss = criterion.instance_loss
         local batch_size = #child_loss
         for j = 1, batch_size do
            self.instance_loss[j] = child_loss[j]
         end
         truncate_sequence(self.instance_loss, batch_size)
      end
   end
   return self.output
end
--- Back-propagate through the child criterion of every output.
-- Expects the same table of outputs that was passed to `updateOutput`, which
-- is where the children in `self.criterions` are created. Each child gradient
-- is copied into a buffer owned by this criterion (allocated with `.new()` on
-- first use and resized as needed), so the returned tensors do not alias the
-- children's own `gradInput`.
-- NOTE(review): `updateOutput` divides each child loss by #input, but these
-- gradients are not divided by #input; confirm the scaling is intentional.
-- @param input table of output tensors
-- @param target target passed unchanged to every child criterion
-- @return self.gradInput, a table with one gradient tensor per output
function AuxiliaryLossCriterion:updateGradInput(input, target)
   local grads = self.gradInput
   local count = #input
   for k = 1, count do
      local child_grad = self.criterions[k]:updateGradInput(input[k], target)
      local buffer = grads[k]
      if not buffer then
         buffer = child_grad.new()
         grads[k] = buffer
      end
      buffer:resizeAs(child_grad):copy(child_grad)
   end
   -- Release buffers left over from a previous call that had more outputs.
   for k = #grads, count + 1, -1 do
      grads[k] = nil
   end
   return grads
end
|