Browse Source

Fix for edac608f

nagadomi 9 years ago
parent
commit
06e073089b
1 changed files with 15 additions and 4 deletions
  1. 15 4
      lib/settings.lua

+ 15 - 4
lib/settings.lua

@@ -23,9 +23,9 @@ cmd:option("-data_dir", "./data", 'path to data directory')
 cmd:option("-backend", "cunn", '(cunn|cudnn)')
 cmd:option("-backend", "cunn", '(cunn|cudnn)')
 cmd:option("-test", "images/miku_small.png", 'path to test image')
 cmd:option("-test", "images/miku_small.png", 'path to test image')
 cmd:option("-model_dir", "./models", 'model directory')
 cmd:option("-model_dir", "./models", 'model directory')
-cmd:option("-method", "scale", 'method to training (noise|scale|noise_scale)')
+cmd:option("-method", "scale", 'method to training (noise|scale|noise_scale|user)')
 cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)')
 cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)')
-cmd:option("-noise_level", 1, '(1|2|3)')
+cmd:option("-noise_level", 1, '(0|1|2|3)')
 cmd:option("-style", "art", '(art|photo)')
 cmd:option("-style", "art", '(art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
 cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
@@ -59,6 +59,7 @@ cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-loss", "y", 'loss (rgb|y)')
 cmd:option("-loss", "y", 'loss (rgb|y)')
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-resume", "", 'resume model file')
+cmd:option("-name", "user", 'model name for user method')
 
 
 local function to_bool(settings, name)
 local function to_bool(settings, name)
    if settings[name] == 1 then
    if settings[name] == 1 then
@@ -99,6 +100,13 @@ if settings.save_history then
 					       settings.model_dir,
 					       settings.model_dir,
 					       settings.noise_level, 
 					       settings.noise_level, 
 					       settings.scale)
 					       settings.scale)
+   elseif settings.method == "user" then
+      settings.model_file = string.format("%s/%s_model.%%d-%%d.t7",
+					  settings.model_dir,
+					  settings.name)
+      settings.model_file_best = string.format("%s/%s_model.t7",
+					       settings.model_dir,
+					       settings.name)
    else
    else
       error("unknown method: " .. settings.method)
       error("unknown method: " .. settings.method)
    end
    end
@@ -112,6 +120,9 @@ else
    elseif settings.method == "noise_scale" then
    elseif settings.method == "noise_scale" then
       settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
       settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
 					  settings.model_dir, settings.noise_level, settings.scale)
 					  settings.model_dir, settings.noise_level, settings.scale)
+   elseif settings.method == "user" then
+      settings.model_file = string.format("%s/%s_model.t7",
+					  settings.model_dir, settings.name)
    else
    else
       error("unknown method: " .. settings.method)
       error("unknown method: " .. settings.method)
    end
    end
@@ -119,8 +130,8 @@ end
 if not (settings.color == "rgb" or settings.color == "y") then
 if not (settings.color == "rgb" or settings.color == "y") then
    error("color must be y or rgb")
    error("color must be y or rgb")
 end
 end
-if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
-   error("scale must be mod-2")
+if not ( settings.scale == 1 or (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0)) then
+   error("scale must be 1 or mod-2")
 end
 end
 if not (settings.style == "art" or
 if not (settings.style == "art" or
 	settings.style == "photo") then
 	settings.style == "photo") then