|
@@ -382,6 +382,8 @@ local function plot(train, valid)
|
|
|
{'validation', torch.Tensor(valid), '-'}})
|
|
|
end
|
|
|
local function train()
|
|
|
+ local x = remove_small_image(torch.load(settings.images))
|
|
|
+ local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
|
|
local hist_train = {}
|
|
|
local hist_valid = {}
|
|
|
local model
|
|
@@ -397,8 +399,6 @@ local function train()
|
|
|
|
|
|
local criterion = create_criterion(model)
|
|
|
local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
|
|
|
- local x = remove_small_image(torch.load(settings.images))
|
|
|
- local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
|
|
local adam_config = {
|
|
|
xLearningRate = settings.learning_rate,
|
|
|
xBatchSize = settings.batch_size,
|