Bladeren bron

Add -save_history option

nagadomi 9 jaren geleden
bovenliggende
commit
9f935835dd
2 gewijzigde bestanden met toevoegingen van 51 en 19 verwijderingen
  1. 25 10
      lib/settings.lua
  2. 26 9
      train.lua

+ 25 - 10
lib/settings.lua

@@ -46,22 +46,37 @@ cmd:option("-validation_crops", 80, 'number of cropping region per image in vali
 cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
 cmd:option("-active_cropping_tries", 10, 'active cropping tries')
 cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
+cmd:option("-save_history", 0, 'save all model (0|1)')
 
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do
    settings[k] = v
 end
-if settings.method == "noise" then
-   settings.model_file = string.format("%s/noise%d_model.t7",
-				       settings.model_dir, settings.noise_level)
-elseif settings.method == "scale" then
-   settings.model_file = string.format("%s/scale%.1fx_model.t7",
-				       settings.model_dir, settings.scale)
-elseif settings.method == "noise_scale" then
-   settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
-				       settings.model_dir, settings.noise_level, settings.scale)
+if settings.save_history == 1 then
+   settings.save_history = true
 else
-   error("unknown method: " .. settings.method)
+   settings.save_history = false
+end
+if settings.save_history then
+   if settings.method == "noise" then
+      settings.model_file = string.format("%s/noise%d_model.%%d-%%d.t7",
+					  settings.model_dir, settings.noise_level)
+   elseif settings.method == "scale" then
+      settings.model_file = string.format("%s/scale%.1fx_model.%%d-%%d.t7",
+					  settings.model_dir, settings.scale)
+   else
+      error("unknown method: " .. settings.method)
+   end
+else
+   if settings.method == "noise" then
+      settings.model_file = string.format("%s/noise%d_model.t7",
+					  settings.model_dir, settings.noise_level)
+   elseif settings.method == "scale" then
+      settings.model_file = string.format("%s/scale%.1fx_model.t7",
+					  settings.model_dir, settings.scale)
+   else
+      error("unknown method: " .. settings.method)
+   end
 end
 if not (settings.color == "rgb" or settings.color == "y") then
    error("color must be y or rgb")

+ 26 - 9
train.lua

@@ -205,15 +205,32 @@ local function train()
 	    lrd_count = 0
 	    best_score = score
 	    print("* update best model")
-	    torch.save(settings.model_file, model)
-	    if settings.method == "noise" then
-	       local log = path.join(settings.model_dir,
-				     ("noise%d_best.png"):format(settings.noise_level))
-	       save_test_jpeg(model, test_image, log)
-	    elseif settings.method == "scale" then
-	       local log = path.join(settings.model_dir,
-				     ("scale%.1f_best.png"):format(settings.scale))
-	       save_test_scale(model, test_image, log)
+	    if settings.save_history then
+	       local model_clone = model:clone()
+	       w2nn.cleanup_model(model_clone)
+	       torch.save(string.format(settings.model_file, epoch, i), model_clone)
+	       if settings.method == "noise" then
+		  local log = path.join(settings.model_dir,
+					("noise%d_best.%d-%d.png"):format(settings.noise_level,
+									  epoch, i))
+		  save_test_jpeg(model, test_image, log)
+	       elseif settings.method == "scale" then
+		  local log = path.join(settings.model_dir,
+					("scale%.1f_best.%d-%d.png"):format(settings.scale,
+									    epoch, i))
+		  save_test_scale(model, test_image, log)
+	       end
+	    else
+	       torch.save(settings.model_file, model)
+	       if settings.method == "noise" then
+		  local log = path.join(settings.model_dir,
+					("noise%d_best.png"):format(settings.noise_level))
+		  save_test_jpeg(model, test_image, log)
+	       elseif settings.method == "scale" then
+		  local log = path.join(settings.model_dir,
+					("scale%.1f_best.png"):format(settings.scale))
+		  save_test_scale(model, test_image, log)
+	       end
 	    end
 	 else
 	    lrd_count = lrd_count + 1