|
@@ -1,9 +1,10 @@
|
|
require 'nn'
|
|
require 'nn'
|
|
local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
|
|
local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
|
|
|
|
|
|
-function AuxiliaryLossCriterion:__init(base_criterion)
|
|
|
|
|
|
+function AuxiliaryLossCriterion:__init(base_criterion, args)
|
|
parent.__init(self)
|
|
parent.__init(self)
|
|
self.base_criterion = base_criterion
|
|
self.base_criterion = base_criterion
|
|
|
|
+ self.args = args
|
|
self.criterions = {}
|
|
self.criterions = {}
|
|
self.gradInput = {}
|
|
self.gradInput = {}
|
|
self.sizeAverage = false
|
|
self.sizeAverage = false
|
|
@@ -14,7 +15,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
|
-- model:training()
|
|
-- model:training()
|
|
for i = 1, #input do
|
|
for i = 1, #input do
|
|
if self.criterions[i] == nil then
|
|
if self.criterions[i] == nil then
|
|
- self.criterions[i] = self.base_criterion()
|
|
|
|
|
|
+ if self.args ~= nil then
|
|
|
|
+ self.criterions[i] = self.base_criterion(table.unpack(self.args))
|
|
|
|
+ else
|
|
|
|
+ self.criterions[i] = self.base_criterion()
|
|
|
|
+ end
|
|
self.criterions[i].sizeAverage = self.sizeAverage
|
|
self.criterions[i].sizeAverage = self.sizeAverage
|
|
if input[i]:type() == "torch.CudaTensor" then
|
|
if input[i]:type() == "torch.CudaTensor" then
|
|
self.criterions[i]:cuda()
|
|
self.criterions[i]:cuda()
|
|
@@ -27,7 +32,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
|
else
|
|
else
|
|
-- model:evaluate()
|
|
-- model:evaluate()
|
|
if self.criterions[1] == nil then
|
|
if self.criterions[1] == nil then
|
|
- self.criterions[1] = self.base_criterion()
|
|
|
|
|
|
+ if self.args ~= nil then
|
|
|
|
+ self.criterions[1] = self.base_criterion(table.unpack(self.args))
|
|
|
|
+ else
|
|
|
|
+ self.criterions[1] = self.base_criterion()
|
|
|
|
+ end
|
|
self.criterions[1].sizeAverage = self.sizeAverage()
|
|
self.criterions[1].sizeAverage = self.sizeAverage()
|
|
if input:type() == "torch.CudaTensor" then
|
|
if input:type() == "torch.CudaTensor" then
|
|
self.criterions[1]:cuda()
|
|
self.criterions[1]:cuda()
|