Browse Source

Fix clearState

nagadomi 8 years ago
parent
commit
063474b2ea
1 changed files with 5 additions and 3 deletions
  1. 5 3
      train.lua

+ 5 - 3
train.lua

@@ -543,8 +543,9 @@ local function train()
 	    best_score = score_for_update
 	    best_score = score_for_update
 	    print("* model has updated")
 	    print("* model has updated")
 	    if settings.save_history then
 	    if settings.save_history then
-	       torch.save(settings.model_file_best, model:clearState(), "ascii")
-	       torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
+	       pmodel:clearState()
+	       torch.save(settings.model_file_best, model, "ascii")
+	       torch.save(string.format(settings.model_file, epoch, i), model, "ascii")
 	       if settings.method == "noise" then
 	       if settings.method == "noise" then
 		  local log = path.join(settings.model_dir,
 		  local log = path.join(settings.model_dir,
 					("noise%d_best.%d-%d.png"):format(settings.noise_level,
 					("noise%d_best.%d-%d.png"):format(settings.noise_level,
@@ -568,7 +569,8 @@ local function train()
 		  save_test_user(model, test_image, log)
 		  save_test_user(model, test_image, log)
 	       end
 	       end
 	    else
 	    else
-	       torch.save(settings.model_file, model:clearState(), "ascii")
+	       pmodel:clearState()
+	       torch.save(settings.model_file, model, "ascii")
 	       if settings.method == "noise" then
 	       if settings.method == "noise" then
 		  local log = path.join(settings.model_dir,
 		  local log = path.join(settings.model_dir,
 					("noise%d_best.png"):format(settings.noise_level))
 					("noise%d_best.png"):format(settings.noise_level))