
Add -resume_epoch option

nagadomi committed 6 years ago
parent commit bdafab9e10
2 files changed with 17 additions and 1 deletion
  1. lib/settings.lua: +1 -0
  2. train.lua: +16 -1

+ 1 - 0
lib/settings.lua

@@ -76,6 +76,7 @@ cmd:option("-oracle_rate", 0.1, '')
 cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-resume", "", 'resume model file')
+cmd:option("-resume_epoch", 1, 'resume epoch')
 cmd:option("-name", "user", 'model name for user method')
 cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
 cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')

+ 16 - 1
train.lua

@@ -506,9 +506,24 @@ local function train()
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local hist_train = {}
    local hist_valid = {}
+   local adam_config = {
+      xLearningRate = settings.learning_rate,
+      xBatchSize = settings.batch_size,
+      xLearningRateDecay = settings.learning_rate_decay,
+      xInstanceLoss = (settings.oracle_rate > 0)
+   }
    local model
    if settings.resume:len() > 0 then
       model = torch.load(settings.resume, "ascii")
+      adam_config.xEvalCount = math.floor((#train_x * settings.patches) / settings.batch_size) * settings.batch_size * settings.inner_epoch * (settings.resume_epoch - 1)
+      print(string.format("set eval count = %d", adam_config.xEvalCount))
+      if adam_config.xEvalCount > 0 then
+	 adam_config.learningRate = adam_config.xLearningRate / (1 + adam_config.xEvalCount * adam_config.xLearningRateDecay)
+	 print(string.format("set learning rate = %E", adam_config.learningRate))
+      else
+	 adam_config.xEvalCount = 0
+	 adam_config.learningRate = adam_config.xLearningRate
+      end
    else
       if stringx.endswith(settings.model, ".lua") then
 	 local create_model = dofile(settings.model)
@@ -576,7 +591,7 @@ local function train()
    end
    local instance_loss = nil
    local pmodel = w2nn.data_parallel(model, settings.gpu)
-   for epoch = 1, settings.epoch do
+   for epoch = settings.resume_epoch, settings.epoch do
       pmodel:training()
       print("# " .. epoch)
       if adam_config.learningRate then
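
To make the resume arithmetic concrete: the diff reconstructs how many training instances Adam would already have evaluated in the resume_epoch - 1 completed epochs, then applies the same 1/(1 + count * decay) schedule described by the -learning_rate_decay help string; if the count comes out non-positive, it falls back to the initial learning rate. Below is a small standalone Lua sketch of that computation. Every numeric value is an assumption for illustration except the 3.0e-7 decay, which is the default shown above:

-- Standalone sketch of the resume arithmetic in the diff above; all
-- values are hypothetical examples except the default decay of 3.0e-7.
local learning_rate       = 0.00025   -- settings.learning_rate (assumed)
local learning_rate_decay = 3.0e-7    -- settings.learning_rate_decay (default)
local batch_size          = 16        -- settings.batch_size (assumed)
local patches             = 64        -- settings.patches (assumed)
local inner_epoch         = 4         -- settings.inner_epoch (assumed)
local num_train           = 5000      -- #train_x (assumed)
local resume_epoch        = 11        -- -resume_epoch (assumed)

-- Instances evaluated in the (resume_epoch - 1) completed epochs,
-- mirroring adam_config.xEvalCount in the diff.
local eval_count = math.floor((num_train * patches) / batch_size)
                 * batch_size * inner_epoch * (resume_epoch - 1)
-- Learning rate after that much decay, mirroring adam_config.learningRate.
local lr = learning_rate / (1 + eval_count * learning_rate_decay)
print(string.format("eval count = %d, learning rate = %E", eval_count, lr))

With these example numbers, eval_count is 12,800,000 and the learning rate comes out near 5.2E-05, so a run resumed at epoch 11 continues from roughly the rate it had decayed to, rather than restarting at the initial rate.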