浏览代码

aux_huber and lbp loss

nagadomi 6 年之前
父节点
当前提交
f317545732
共有 2 个文件被更改,包括 30 次插入3 次删除
  1. 3 1
      lib/Print.lua
  2. 27 2
      train.lua

+ 3 - 1
lib/Print.lua

@@ -1,9 +1,11 @@
 local Print, parent = torch.class('w2nn.Print','nn.Module')
 
-function Print:__init()
+function Print:__init(id)
    parent.__init(self)
+   self.id = id
 end
 function Print:updateOutput(input)
+   print("----", self.id)
    print(input:size())
    self.output:resizeAs(input)
    self.output:copy(input)

+ 27 - 2
train.lua

@@ -369,6 +369,31 @@ local function create_criterion(model)
       local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
       aux.sizeAverage = true
       return aux:cuda()
+   elseif settings.loss == "aux_huber" then
+      local args = {}
+      if reconstruct.is_rgb(model) then
+	 local offset = reconstruct.offset_size(model)
+	 local output_w = settings.crop_size - offset * 2
+	 local weight = torch.Tensor(3, output_w * output_w)
+	 weight[1]:fill(0.29891 * 3) -- R
+	 weight[2]:fill(0.58661 * 3) -- G
+	 weight[3]:fill(0.11448 * 3) -- B
+	 args = {weight, 0.1, {0.0, 1.0}}
+      else
+	 local offset = reconstruct.offset_size(model)
+	 local output_w = settings.crop_size - offset * 2
+	 local weight = torch.Tensor(1, output_w * output_w)
+	 weight[1]:fill(1.0)
+	 args = {weight, 0.1, {0.0, 1.0}}
+      end
+      local aux = w2nn.AuxiliaryLossCriterion(w2nn.ClippedWeightedHuberCriterion, args)
+      return aux:cuda()
+   elseif settings.loss == "lbp" then
+      if reconstruct.is_rgb(model) then
+	 return w2nn.RandomBinaryCriterion(3, 512):cuda()
+      else
+	 return w2nn.RandomBinaryCriterion(1, 512):cuda()
+      end
    else
       error("unsupported loss .." .. settings.loss)
    end
@@ -506,9 +531,9 @@ local function train()
    local criterion = create_criterion(model)
    local eval_metric = nil
    if settings.loss:find("aux_") ~= nil then
-      eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
+      eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion):cuda()
    else
-      eval_metric = w2nn.ClippedMSECriterion()
+      eval_metric = w2nn.ClippedMSECriterion():cuda()
    end
    local adam_config = {
       xLearningRate = settings.learning_rate,