Browse Source

Fix division by zero error in validate()

nagadomi 8 năm trước cách đây
mục cha
commit
dac1b89750
1 tập tin đã thay đổi với 1 bổ sung1 xóa
  1. 1 1
      train.lua

+ 1 - 1
train.lua

@@ -290,7 +290,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
       local batch_mse = eval_metric:forward(z, targets)
       loss = loss + criterion:forward(z, targets)
       mse = mse + batch_mse
-      psnr = psnr + (10 * math.log10(1 / batch_mse))
+      psnr = psnr + (10 * math.log10(1 / (batch_mse + 1.0e-6)))
       loss_count = loss_count + 1
       if loss_count % 10 == 0 then
 	 xlua.progress(t, #data)