Browse Source

Add support for AuxiliaryLoss

nagadomi 6 years ago
parent
commit
56536ac133
4 changed files with 104 additions and 1 deletions
  1. 51 0
      lib/AuxiliaryLossCriterion.lua
  2. 40 0
      lib/AuxiliaryLossTable.lua
  3. 3 0
      lib/w2nn.lua
  4. 10 1
      train.lua

+ 51 - 0
lib/AuxiliaryLossCriterion.lua

@@ -0,0 +1,51 @@
+require 'nn'
+local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
+
+function AuxiliaryLossCriterion:__init(base_criterion)
+   parent.__init(self)
+   self.base_criterion = base_criterion
+   self.criterions = {}
+   self.gradInput = {}
+   self.sizeAverage = false
+end
+function AuxiliaryLossCriterion:updateOutput(input, target)
+   local sum_output = 0
+   if type(input) == "table" then
+      -- model:training()
+      for i = 1, #input do
+	 if self.criterions[i] == nil then
+	    self.criterions[i] = self.base_criterion()
+	    self.criterions[i].sizeAverage = self.sizeAverage
+	    if input[i]:type() == "torch.CudaTensor" then
+	       self.criterions[i]:cuda()
+	    end
+	 end
+	 local output = self.criterions[i]:updateOutput(input[i], target)
+	 sum_output = sum_output + output
+      end
+      self.output = sum_output / #input
+   else
+      -- model:evaluate()
+      if self.criterions[1] == nil then
+	 self.criterions[1] = self.base_criterion()
+	 self.criterions[1].sizeAverage = self.sizeAverage()
+	 if input:type() == "torch.CudaTensor" then
+	    self.criterions[1]:cuda()
+	 end
+      end
+      self.output = self.criterions[1]:updateOutput(input, target)
+   end
+   return self.output
+end
+
+function AuxiliaryLossCriterion:updateGradInput(input, target)
+   for i=1,#input do
+      local gradInput = self.criterions[i]:updateGradInput(input[i], target)
+      self.gradInput[i] = self.gradInput[i] or gradInput.new()
+      self.gradInput[i]:resizeAs(gradInput):copy(gradInput)
+   end
+   for i=#input+1, #self.gradInput do
+       self.gradInput[i] = nil
+   end
+   return self.gradInput
+end

+ 40 - 0
lib/AuxiliaryLossTable.lua

@@ -0,0 +1,40 @@
+require 'nn'
+local AuxiliaryLossTable, parent = torch.class('w2nn.AuxiliaryLossTable', 'nn.Module')
+
+function AuxiliaryLossTable:__init(i)
+   parent.__init(self)
+   self.i = i or 1
+   self.gradInput = {}
+   self.output_table = {}
+   self.output_tensor = torch.Tensor()
+end
+
+function AuxiliaryLossTable:updateOutput(input)
+   if self.train then
+      for i=1,#input do
+	 self.output_table[i] = self.output_table[i] or input[1].new()
+	 self.output_table[i]:resizeAs(input[i]):copy(input[i])
+      end
+      for i=#input+1, #self.output_table do
+	 self.output_table[i] = nil
+      end
+      self.output = self.output_table
+   else
+      self.output_tensor:resizeAs(input[1]):copy(input[1])
+      self.output_tensor:copy(input[self.i])
+      self.output = self.output_tensor
+   end
+   return self.output
+end
+
+function AuxiliaryLossTable:updateGradInput(input, gradOutput)
+   for i=1,#input do
+      self.gradInput[i] = self.gradInput[i] or input[1].new()
+      self.gradInput[i]:resizeAs(input[i]):copy(gradOutput[i])
+   end
+   for i=#input+1, #self.gradInput do
+       self.gradInput[i] = nil
+   end
+
+   return self.gradInput
+end

+ 3 - 0
lib/w2nn.lua

@@ -77,5 +77,8 @@ else
    require 'ShakeShakeTable'
    require 'PrintTable'
    require 'Print'
+   require 'AuxiliaryLossTable'
+   require 'AuxiliaryLossCriterion'
    return w2nn
 end
+

+ 10 - 1
train.lua

@@ -365,6 +365,10 @@ local function create_criterion(model)
       local bce = nn.BCECriterion()
       bce.sizeAverage = true
       return bce:cuda()
+   elseif settings.loss == "aux_bce" then
+      local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
+      aux.sizeAverage = true
+      return aux:cuda()
    else
       error("unsupported loss .." .. settings.loss)
    end
@@ -500,7 +504,12 @@ local function train()
    transform_pool_init(reconstruct.has_resize(model), offset)
 
    local criterion = create_criterion(model)
-   local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
+   local eval_metric = nil
+   if settings.loss:find("aux_") ~= nil then
+      eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
+   else
+      eval_metric = w2nn.ClippedMSECriterion()
+   end
    local adam_config = {
       xLearningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,