Prechádzať zdrojové kódy

lbp loss: multilayer support

nagadomi 6 rokov pred
rodič
commit
b3fb286258
2 zmenil súbory, kde vykonal 35 pridanie a 9 odobranie
  1. 23 9
      lib/RandomBinaryCriterion.lua
  2. 12 0
      train.lua

+ 23 - 9
lib/RandomBinaryCriterion.lua

@@ -1,21 +1,35 @@
 local RandomBinaryCriterion, parent = torch.class('w2nn.RandomBinaryCriterion','nn.Criterion')
 
-local function create_filters(ch, n, k)
-   local filter = w2nn.RandomBinaryConvolution(ch, n, k, k)
-   -- channel identify
-   for i = 1, ch do
-      filter.weight[i]:fill(0)
-      filter.weight[i][i][math.floor(k/2)+1][math.floor(k/2)+1] = 1
+local function create_filters(ch, n, k, layers)
+   local model = nn.Sequential()
+   for i = 1, layers do
+      local n_input = ch
+      if i > 1 then
+	 n_input = n
+      end
+      local filter = w2nn.RandomBinaryConvolution(n_input, n, k, k)
+      if i == 1 then
+	 -- channel identity
+	 for j = 1, ch do
+	    filter.weight[i]:fill(0)
+	    filter.weight[i][i][math.floor(k/2)+1][math.floor(k/2)+1] = 1
+	 end
+      end
+      model:add(filter)
+      --if layers > 1 and i ~= layers then
+      --   model:add(nn.Sigmoid(true))
+      --end
    end
-   return filter
+   return model
 end
-function RandomBinaryCriterion:__init(ch, n, k)
+function RandomBinaryCriterion:__init(ch, n, k, layers)
    parent.__init(self)
+   self.layers = layers or 1
    self.gamma = 0.1
    self.n = n or 32
    self.k = k or 3
    self.ch = ch
-   self.filter1 = create_filters(self.ch, self.n, self.k)
+   self.filter1 = create_filters(self.ch, self.n, self.k, self.layers)
    self.filter2 = self.filter1:clone()
    self.diff = torch.Tensor()
    self.diff_abs = torch.Tensor()

+ 12 - 0
train.lua

@@ -394,12 +394,24 @@ local function create_criterion(model)
       else
 	 return w2nn.RandomBinaryCriterion(1, 128):cuda()
       end
+   elseif settings.loss == "lbp2" then
+      if reconstruct.is_rgb(model) then
+	 return w2nn.RandomBinaryCriterion(3, 128, 3, 2):cuda()
+      else
+	 return w2nn.RandomBinaryCriterion(1, 128, 3, 2):cuda()
+      end
    elseif settings.loss == "aux_lbp" then
       if reconstruct.is_rgb(model) then
 	 return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128}):cuda()
       else
 	 return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128}):cuda()
       end
+   elseif settings.loss == "aux_lbp2" then
+      if reconstruct.is_rgb(model) then
+	 return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128, 3, 2}):cuda()
+      else
+	 return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128, 3, 2}):cuda()
+      end
    else
       error("unsupported loss .." .. settings.loss)
    end