nagadomi 6 years ago
parent
commit
86d8fe96da

+ 26 - 4
lib/AuxiliaryLossCriterion.lua

@@ -5,14 +5,18 @@ function AuxiliaryLossCriterion:__init(base_criterion, args)
    parent.__init(self)
    self.base_criterion = base_criterion
    self.args = args
-   self.criterions = {}
    self.gradInput = {}
    self.sizeAverage = false
+   self.criterions = {}
+   if self.base_criterion.has_instance_loss then
+      self.instance_loss = {}
+   end
 end
 function AuxiliaryLossCriterion:updateOutput(input, target)
    local sum_output = 0
    if type(input) == "table" then
       -- model:training()
+      self.output = 0
       for i = 1, #input do
 	 if self.criterions[i] == nil then
 	    if self.args ~= nil then
@@ -25,10 +29,22 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
 	       self.criterions[i]:cuda()
 	    end
 	 end
-	 local output = self.criterions[i]:updateOutput(input[i], target)
-	 sum_output = sum_output + output
+	 self.output = self.output + self.criterions[i]:updateOutput(input[i], target) / #input
+
+	 if self.instance_loss then
+	    local batch_size = #self.criterions[i].instance_loss
+	    local scale = 1.0 / #input
+	    if i == 1 then
+	       for j = 1, batch_size do
+		  self.instance_loss[j] = self.criterions[i].instance_loss[j] * scale
+	       end
+	    else
+	       for j = 1, batch_size do
+		  self.instance_loss[j] = self.instance_loss[j] + self.criterions[i].instance_loss[j] * scale
+	       end
+	    end
+	 end
       end
-      self.output = sum_output / #input
    else
       -- model:evaluate()
       if self.criterions[1] == nil then
@@ -43,6 +59,12 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
 	 end
       end
       self.output = self.criterions[1]:updateOutput(input, target)
+      if self.instance_loss then
+	 local batch_size = #self.criterions[1].instance_loss
+	 for j = 1, batch_size do
+	    self.instance_loss[j] = self.criterions[1].instance_loss[j]
+	 end
+      end
    end
    return self.output
 end
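
With this change, AuxiliaryLossCriterion accumulates the averaged loss directly into self.output during training and, when the base criterion sets has_instance_loss, also averages the per-sample losses of every auxiliary branch into self.instance_loss. A minimal sketch of that aggregation in plain Lua (the branch losses below are made-up toy numbers, not repository data):

local per_branch = {
   {0.4, 0.2}, -- instance losses from auxiliary output 1 (batch of 2)
   {0.6, 0.0}, -- instance losses from auxiliary output 2
}
local instance_loss = {}
local scale = 1.0 / #per_branch
for i = 1, #per_branch do
   for j = 1, #per_branch[i] do
      -- each branch contributes 1/#input of its per-sample loss
      instance_loss[j] = (instance_loss[j] or 0) + per_branch[i][j] * scale
   end
end
print(instance_loss[1], instance_loss[2]) -- 0.5   0.1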

+ 11 - 2
lib/ClippedMSECriterion.lua

@@ -1,19 +1,28 @@
 local ClippedMSECriterion, parent = torch.class('w2nn.ClippedMSECriterion','nn.Criterion')
 
+ClippedMSECriterion.has_instance_loss = true
 function ClippedMSECriterion:__init(min, max)
    parent.__init(self)
    self.min = min or 0
    self.max = max or 1
    self.diff = torch.Tensor()
    self.diff_pow2 = torch.Tensor()
+   self.instance_loss = {}
 end
 function ClippedMSECriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
    self.diff:clamp(self.min, self.max)
    self.diff:add(-1, target)
    self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)
-   self.output = self.diff_pow2:sum() / input:nElement()
-   return self.output
+   self.instance_loss = {}
+   self.output = 0
+   local scale = 1.0 / input:size(1)
+   for i = 1, input:size(1) do
+      local instance_loss = self.diff_pow2[i]:sum() / self.diff_pow2[i]:nElement()
+      self.instance_loss[i] = instance_loss
+      self.output = self.output + instance_loss * scale
+   end
+   return self.output
 end
 function ClippedMSECriterion:updateGradInput(input, target)
    local norm = 1.0 / input:nElement()
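
ClippedMSECriterion now advertises has_instance_loss and fills self.instance_loss with one MSE value per sample, rather than reporting only the batch mean (note the scale factor keeps self.output equal to the returned mean). A hedged sketch of the per-sample computation with torch (tensor values are illustrative, not from the repository):

require 'torch'
local input  = torch.Tensor({{0.2, 1.3}, {0.9, 0.4}})
local target = torch.Tensor({{0.1, 1.0}, {0.8, 0.5}})
-- clamp the prediction to [0, 1] before taking the difference
local diff = input:clone():clamp(0, 1):add(-1, target)
local diff_pow2 = diff:clone():pow(2)
for i = 1, input:size(1) do
   -- mean squared error of sample i alone
   print(i, diff_pow2[i]:sum() / diff_pow2[i]:nElement())
end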

+ 1 - 3
lib/LBPCriterion.lua

@@ -55,9 +55,7 @@ function LBPCriterion:updateOutput(input, target)
 
    -- huber loss
    self.diff:resizeAs(lb1):copy(lb1)
-   for i = 1, lb1:size(1) do
-      self.diff[i]:add(-1, lb2[i])
-   end
+   self.diff:add(-1, lb2)
    self.diff_abs:resizeAs(self.diff):copy(self.diff):abs()
    
    local square_targets = self.diff[torch.lt(self.diff_abs, self.gamma)]
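
The per-row loop was redundant here: lb1 and lb2 have the same shape, so a single in-place add over the whole tensor computes the same difference. A quick equivalence check with random tensors (illustrative, not repository data):

require 'torch'
local a, b = torch.rand(4, 3), torch.rand(4, 3)
local looped = a:clone()
for i = 1, a:size(1) do
   looped[i]:add(-1, b[i])                 -- old style: subtract row by row
end
local vectorized = a:clone():add(-1, b)    -- new style: one tensor op
print((looped - vectorized):abs():max())   -- 0: identical results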

+ 5 - 22
lib/minibatch_adam.lua

@@ -44,29 +44,12 @@ local function minibatch_adam(model, criterion, eval_metric,
 	 gradParameters:zero()
 	 local output = model:forward(inputs)
 	 local f = criterion:forward(output, targets)
-	 local se = 0
+	 local se = eval_metric:forward(output, targets)
 	 if config.xInstanceLoss then
-	    if type(output) then
-	       local tbl = {}
-	       for i = 1, batch_size do
-		  for j = 1, #output do
-		     tbl[j] = output[j][i]
-		  end
-		  local el = eval_metric:forward(tbl, targets[i])
-		  se = se + el
-		  instance_loss[shuffle[t + i - 1]] = el
-	       end
-	       se = (se / batch_size)
-	    else
-	       for i = 1, batch_size do
-		  local el = eval_metric:forward(output[i], targets[i])
-		  se = se + el
-		  instance_loss[shuffle[t + i - 1]] = el
-	       end
-	       se = (se / batch_size)
-	    end	       
-	 else
-	    se = eval_metric:forward(output, targets)
+	    assert(eval_metric.instance_loss, "eval_metric does not support instance_loss")
+	    for i = 1, #eval_metric.instance_loss do
+	       instance_loss[shuffle[t + i - 1]] = eval_metric.instance_loss[i]
+	    end
 	 end
 	 sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6)))
 	 sum_eval = sum_eval + se
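
With config.xInstanceLoss enabled, per-sample losses now come straight from eval_metric.instance_loss and are written back to dataset positions through the shuffle permutation, replacing the old loop of per-example forward passes. A toy sketch of that bookkeeping (all values made up; shuffle and t stand in for the minibatch offset variables from the surrounding loop):

local shuffle = {3, 1, 4, 2}    -- permutation of a 4-example dataset
local t = 1                     -- offset of the current minibatch
local batch_losses = {0.7, 0.2} -- eval_metric.instance_loss for this batch
local instance_loss = {}
for i = 1, #batch_losses do
   -- map position-in-batch back to position-in-dataset
   instance_loss[shuffle[t + i - 1]] = batch_losses[i]
end
-- instance_loss[3] == 0.7, instance_loss[1] == 0.2

The PSNR accumulation just above assumes values in [0, 1], where PSNR = 10 * log10(1 / MSE); the 1.0e-6 term guards against taking the log of zero.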