@@ -506,9 +506,24 @@ local function train()
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local hist_train = {}
    local hist_valid = {}
+   local adam_config = {
+      xLearningRate = settings.learning_rate,
+      xBatchSize = settings.batch_size,
+      xLearningRateDecay = settings.learning_rate_decay,
+      xInstanceLoss = (settings.oracle_rate > 0)
+   }
    local model
    if settings.resume:len() > 0 then
       model = torch.load(settings.resume, "ascii")
+      adam_config.xEvalCount = math.floor((#train_x * settings.patches) / settings.batch_size) * settings.batch_size * settings.inner_epoch * (settings.resume_epoch - 1)
+      print(string.format("set eval count = %d", adam_config.xEvalCount))
+      if adam_config.xEvalCount > 0 then
+         adam_config.learningRate = adam_config.xLearningRate / (1 + adam_config.xEvalCount * adam_config.xLearningRateDecay)
+         print(string.format("set learning rate = %E", adam_config.learningRate))
+      else
+         adam_config.xEvalCount = 0
+         adam_config.learningRate = adam_config.xLearningRate
+      end
    else
       if stringx.endswith(settings.model, ".lua") then
          local create_model = dofile(settings.model)
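The resume branch above reconstructs the optimizer's evaluation count from the epochs already completed, then re-derives the annealed learning rate as lr_t = lr_0 / (1 + t * decay), the same 1/(1 + t*lrd) annealing that torch/optim routines apply internally. A standalone sketch of that arithmetic (function names and example settings here are illustrative, not part of waifu2x):

   -- Hypothetical helpers mirroring the resume math in the hunk above.
   local function resumed_eval_count(n_train, patches, batch_size, inner_epoch, resume_epoch)
      -- evaluations per epoch = whole batches * batch size * inner epochs
      local per_epoch = math.floor((n_train * patches) / batch_size) * batch_size * inner_epoch
      return per_epoch * (resume_epoch - 1)   -- epochs already completed
   end

   local function annealed_learning_rate(lr0, decay, eval_count)
      return lr0 / (1 + eval_count * decay)   -- lr_t = lr_0 / (1 + t * decay)
   end

   -- e.g. resuming at epoch 11 with hypothetical settings:
   local t = resumed_eval_count(10000, 64, 16, 4, 11)
   print(string.format("eval count = %d, lr = %E",
                       t, annealed_learning_rate(0.00025, 3.0e-7, t)))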
@@ -576,7 +591,7 @@ local function train()
    end
    local instance_loss = nil
    local pmodel = w2nn.data_parallel(model, settings.gpu)
-   for epoch = 1, settings.epoch do
+   for epoch = settings.resume_epoch, settings.epoch do
       pmodel:training()
       print("# " .. epoch)
       if adam_config.learningRate then
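With the loop now starting at settings.resume_epoch (presumably 1 for a fresh run), a resumed run continues the epoch numbering in its logs, and, via xEvalCount, the decay schedule, instead of restarting both from scratch. A minimal, purely illustrative skeleton of the change:

   -- Illustrative only; values are hypothetical.
   local settings = { resume_epoch = 11, epoch = 30 }
   for epoch = settings.resume_epoch, settings.epoch do
      print("# " .. epoch)   -- numbering picks up at 11, not 1
   end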