Browse Source

Stop calculate the instance loss when oracle_rate=0

nagadomi 8 năm trước cách đây
mục cha
commit
451ee1407f
2 tập tin đã thay đổi với 12 bổ sung6 xóa
  1. 10 5
      lib/minibatch_adam.lua
  2. 2 1
      train.lua

+ 10 - 5
lib/minibatch_adam.lua

@@ -45,12 +45,17 @@ local function minibatch_adam(model, criterion, eval_metric,
 	 local output = model:forward(inputs)
 	 local f = criterion:forward(output, targets)
 	 local se = 0
-	 for i = 1, batch_size do
-	    local el = eval_metric:forward(output[i], targets[i])
-	    se = se + el
-	    instance_loss[shuffle[t + i - 1]] = el
+	 if config.xInstanceLoss then
+	    for i = 1, batch_size do
+	       local el = eval_metric:forward(output[i], targets[i])
+	       se = se + el
+	       instance_loss[shuffle[t + i - 1]] = el
+	    end
+	    se = (se / batch_size)
+	 else
+	    se = eval_metric:forward(output, targets)
 	 end
-	 sum_eval = sum_eval + (se / batch_size)
+	 sum_eval = sum_eval + se
 	 sum_loss = sum_loss + f
 	 count_loss = count_loss + 1
 	 model:backward(inputs, criterion:backward(output, targets))

+ 2 - 1
train.lua

@@ -374,7 +374,8 @@ local function train()
    local adam_config = {
       xLearningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
-      xLearningRateDecay = settings.learning_rate_decay
+      xLearningRateDecay = settings.learning_rate_decay,
+      xInstanceLoss = (settings.oracle_rate > 0)
    }
    local ch = nil
    if settings.color == "y" then