瀏覽代碼

Add -resume option

nagadomi 9 年之前
父節點
當前提交
e5cfd3dfce
共有 2 個文件被更改,包括 7 次插入1 次删除
  1. 1 0
      lib/settings.lua
  2. 6 1
      train.lua

+ 1 - 0
lib/settings.lua

@@ -60,6 +60,7 @@ cmd:option("-oracle_rate", 0.1, '')
 cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-loss", "y", 'loss (rgb|y)')
 cmd:option("-loss", "y", 'loss (rgb|y)')
+cmd:option("-resume", "", 'resume model file')
 
 
 local function to_bool(settings, name)
 local function to_bool(settings, name)
    if settings[name] == 1 then
    if settings[name] == 1 then

+ 6 - 1
train.lua

@@ -278,7 +278,12 @@ end
 local function train()
 local function train()
    local hist_train = {}
    local hist_train = {}
    local hist_valid = {}
    local hist_valid = {}
-   local model = srcnn.create(settings.model, settings.backend, settings.color)
+   local model
+   if settings.resume:len() > 0 then
+      model = torch.load(settings.resume, "ascii")
+   else
+      model = srcnn.create(settings.model, settings.backend, settings.color)
+   end
    local offset = reconstruct.offset_size(model)
    local offset = reconstruct.offset_size(model)
    local pairwise_func = function(x, is_validation, n)
    local pairwise_func = function(x, is_validation, n)
       return transformer(model, x, is_validation, n, offset)
       return transformer(model, x, is_validation, n, offset)