Rename RandomBinaryCriterion to LBPCriterion

nagadomi, 6 years ago
commit aea254eab5
3 changed files with 9 additions and 10 deletions
  1. lib/LBPCriterion.lua (+0, -1)
  2. lib/w2nn.lua (+1, -1)
  3. train.lua (+8, -8)

lib/RandomBinaryCriterion.lua → lib/LBPCriterion.lua (+0, -1)

@@ -64,7 +64,6 @@ function RandomBinaryCriterion:updateOutput(input, target)
   local linear_targets = self.diff[torch.ge(self.diff_abs, self.gamma)]
   local square_loss = self.square_loss_buff:resizeAs(square_targets):copy(square_targets):pow(2.0):mul(0.5):sum()
   local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():add(-0.5 * self.gamma):mul(self.gamma):sum()
-
   --self.outlier_rate = linear_targets:nElement() / input:nElement()
   self.output = (square_loss + linear_loss) / lb1:nElement()

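For context on what the touched function computes: the criterion applies a Huber-style loss to its LBP-filtered difference tensor, charging 0.5 * d^2 for elements whose magnitude is below the gamma threshold and gamma * (|d| - 0.5 * gamma) for the rest. Below is a minimal standalone sketch of that square/linear split in plain Torch; the module's preallocated buffers and the lb1 normalizer are simplified away, and diff and gamma are illustrative placeholders rather than the class's fields.

    require 'torch'

    -- Huber-style split mirroring the square/linear terms in updateOutput above.
    -- diff: an arbitrary error tensor; gamma: the inlier/outlier threshold.
    local function huber_sum(diff, gamma)
       local diff_abs = torch.abs(diff)
       local square_targets = diff[torch.lt(diff_abs, gamma)] -- inliers: quadratic penalty
       local linear_targets = diff[torch.ge(diff_abs, gamma)] -- outliers: linear penalty
       local square_loss = torch.pow(square_targets, 2.0):mul(0.5):sum()
       local linear_loss = torch.abs(linear_targets):add(-0.5 * gamma):mul(gamma):sum()
       return (square_loss + linear_loss) / diff:nElement()
    end

    -- for gamma = 0.1: {0.01, -0.02} fall in the quadratic branch, {0.5, -1.0} in the linear one
    print(huber_sum(torch.Tensor({0.01, -0.02, 0.5, -1.0}), 0.1))
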
lib/w2nn.lua (+1, -1)

@@ -82,7 +82,7 @@ else
   require 'AuxiliaryLossCriterion'
   require 'GradWeight'
   require 'RandomBinaryConvolution'
-  require 'RandomBinaryCriterion'
+  require 'LBPCriterion'
   require 'EdgeFilter'
   require 'ScaleTable'
   return w2nn

train.lua (+8, -8)

@@ -390,27 +390,27 @@ local function create_criterion(model)
      return aux:cuda()
   elseif settings.loss == "lbp" then
      if reconstruct.is_rgb(model) then
-        return w2nn.RandomBinaryCriterion(3, 128):cuda()
+        return w2nn.LBPCriterion(3, 128):cuda()
      else
-        return w2nn.RandomBinaryCriterion(1, 128):cuda()
+        return w2nn.LBPCriterion(1, 128):cuda()
      end
   elseif settings.loss == "lbp2" then
      if reconstruct.is_rgb(model) then
-        return w2nn.RandomBinaryCriterion(3, 128, 3, 2):cuda()
+        return w2nn.LBPCriterion(3, 128, 3, 2):cuda()
      else
-        return w2nn.RandomBinaryCriterion(1, 128, 3, 2):cuda()
+        return w2nn.LBPCriterion(1, 128, 3, 2):cuda()
      end
   elseif settings.loss == "aux_lbp" then
      if reconstruct.is_rgb(model) then
-        return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128}):cuda()
+        return w2nn.AuxiliaryLossCriterion(w2nn.LBPCriterion, {3, 128}):cuda()
      else
-        return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128}):cuda()
+        return w2nn.AuxiliaryLossCriterion(w2nn.LBPCriterion, {1, 128}):cuda()
      end
   elseif settings.loss == "aux_lbp2" then
      if reconstruct.is_rgb(model) then
-        return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128, 3, 2}):cuda()
+        return w2nn.AuxiliaryLossCriterion(w2nn.LBPCriterion, {3, 128, 3, 2}):cuda()
      else
-        return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128, 3, 2}):cuda()
+        return w2nn.AuxiliaryLossCriterion(w2nn.LBPCriterion, {1, 128, 3, 2}):cuda()
      end
   else
      error("unsupported loss .." .. settings.loss)
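For anyone adapting their own training scripts to this rename: the constructor arguments and the AuxiliaryLossCriterion wrapping are unchanged, only the class name differs. A minimal usage sketch, assuming lib/ is on the Lua package path and a working cutorch install; the extra 3, 2 arguments are copied verbatim from the lbp2 branch above, not re-documented here.

    require 'w2nn'

    -- RGB model, default arguments (the "lbp" branch above)
    local lbp_rgb = w2nn.LBPCriterion(3, 128):cuda()

    -- grayscale model with the extra options used by the "lbp2" branch
    local lbp2_y = w2nn.LBPCriterion(1, 128, 3, 2):cuda()

    -- wrapped for auxiliary-output models, as in the "aux_lbp" branch
    local aux_lbp = w2nn.AuxiliaryLossCriterion(w2nn.LBPCriterion, {3, 128}):cuda()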