소스 검색

Add aux_lbp; Fix oracle

nagadomi 6 년 전
부모
커밋
89c8f5db8e
2개의 변경된 파일25개의 추가작업 그리고 6개의 파일을 삭제
  1. 19 6
      lib/minibatch_adam.lua
  2. 6 0
      train.lua

+ 19 - 6
lib/minibatch_adam.lua

@@ -46,12 +46,25 @@ local function minibatch_adam(model, criterion, eval_metric,
 	 local f = criterion:forward(output, targets)
 	 local se = 0
 	 if config.xInstanceLoss then
-	    for i = 1, batch_size do
-	       local el = eval_metric:forward(output[i], targets[i])
-	       se = se + el
-	       instance_loss[shuffle[t + i - 1]] = el
-	    end
-	    se = (se / batch_size)
+	    if type(output) then
+	       local tbl = {}
+	       for i = 1, batch_size do
+		  for j = 1, #output do
+		     tbl[j] = output[j][i]
+		  end
+		  local el = eval_metric:forward(tbl, targets[i])
+		  se = se + el
+		  instance_loss[shuffle[t + i - 1]] = el
+	       end
+	       se = (se / batch_size)
+	    else
+	       for i = 1, batch_size do
+		  local el = eval_metric:forward(output[i], targets[i])
+		  se = se + el
+		  instance_loss[shuffle[t + i - 1]] = el
+	       end
+	       se = (se / batch_size)
+	    end	       
 	 else
 	    se = eval_metric:forward(output, targets)
 	 end

+ 6 - 0
train.lua

@@ -394,6 +394,12 @@ local function create_criterion(model)
       else
 	 return w2nn.RandomBinaryCriterion(1, 512):cuda()
       end
+   elseif settings.loss == "aux_lbp" then
+      if reconstruct.is_rgb(model) then
+	 return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 512}):cuda()
+      else
+	 return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 512}):cuda()
+      end
    else
       error("unsupported loss .." .. settings.loss)
    end