浏览代码

Add support for criterion arguments in AuxiliaryLossCriterion

nagadomi 6 年之前
父节点
当前提交
aef969d64b
共有 1 个文件被更改,包括 12 次插入3 次删除
  1. 12 3
      lib/AuxiliaryLossCriterion.lua

+ 12 - 3
lib/AuxiliaryLossCriterion.lua

@@ -1,9 +1,10 @@
 require 'nn'
 local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
 
-function AuxiliaryLossCriterion:__init(base_criterion)
+function AuxiliaryLossCriterion:__init(base_criterion, args)
    parent.__init(self)
    self.base_criterion = base_criterion
+   self.args = args
    self.criterions = {}
    self.gradInput = {}
    self.sizeAverage = false
@@ -14,7 +15,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
       -- model:training()
       for i = 1, #input do
 	 if self.criterions[i] == nil then
-	    self.criterions[i] = self.base_criterion()
+	    if self.args ~= nil then
+	       self.criterions[i] = self.base_criterion(table.unpack(self.args))
+	    else
+	       self.criterions[i] = self.base_criterion()
+	    end
 	    self.criterions[i].sizeAverage = self.sizeAverage
 	    if input[i]:type() == "torch.CudaTensor" then
 	       self.criterions[i]:cuda()
@@ -27,7 +32,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
    else
       -- model:evaluate()
       if self.criterions[1] == nil then
-	 self.criterions[1] = self.base_criterion()
+	 if self.args ~= nil then
+	    self.criterions[1] = self.base_criterion(table.unpack(self.args))
+	 else
+	    self.criterions[1] = self.base_criterion()
+	 end
 	 self.criterions[1].sizeAverage = self.sizeAverage()
 	 if input:type() == "torch.CudaTensor" then
 	    self.criterions[1]:cuda()