Ver código fonte

Fix training mode

nagadomi 9 anos atrás
pai
commit
634046d5f0
1 arquivos alterados com 1 adições e 0 exclusões
  1. 1 0
      train.lua

+ 1 - 0
train.lua

@@ -309,6 +309,7 @@ local function train()
       instance_loss = torch.Tensor(x:size(1)):zero()
 
       for i = 1, settings.inner_epoch do
+	 model:training()
 	 local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
 	 instance_loss:copy(il)
 	 print(train_score)