Преглед изворни кода

add -update_criterion option for back compatible

nagadomi пре 8 година
родитељ
комит
b111901cbb
2 измењених фајлова са 10 додато и 3 уклоњено
  1. 1 0
      lib/settings.lua
  2. 9 3
      train.lua

+ 1 - 0
lib/settings.lua

@@ -76,6 +76,7 @@ cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
 cmd:option("-gpu", 1, 'Device ID')
 cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
+cmd:option("-update_criterion", "mse", 'mse|loss')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then

+ 9 - 3
train.lua

@@ -532,9 +532,15 @@ local function train()
 	 if settings.plot then
 	    plot(hist_train, hist_valid)
 	 end
-	 if score.loss < best_score then
+	 local score_for_update
+	 if settings.update_criterion == "mse" then
+	    score_for_update = score.MSE
+	 else
+	    score_for_update = score.loss
+	 end
+	 if score_for_update < best_score then
 	    local test_image = image_loader.load_float(settings.test) -- reload
-	    best_score = score.loss
+	    best_score = score_for_update
 	    print("* model has updated")
 	    if settings.save_history then
 	       torch.save(settings.model_file_best, model:clearState(), "ascii")
@@ -583,7 +589,7 @@ local function train()
 	       end
 	    end
 	 end
-	 print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
+	 print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", best: " .. best_score)
 	 collectgarbage()
       end
    end