|
@@ -166,6 +166,7 @@ local function train()
|
|
|
return transformer(x, is_validation, n, offset)
|
|
|
end
|
|
|
local criterion = create_criterion(model)
|
|
|
+ local eval_metric = w2nn.PSNRCriterion():cuda()
|
|
|
local x = torch.load(settings.images)
|
|
|
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
|
|
local adam_config = {
|
|
@@ -179,7 +180,7 @@ local function train()
|
|
|
elseif settings.color == "rgb" then
|
|
|
ch = 3
|
|
|
end
|
|
|
- local best_score = 100000.0
|
|
|
+ local best_score = 0.0
|
|
|
print("# make validation-set")
|
|
|
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
|
|
settings.validation_crops,
|
|
@@ -200,11 +201,11 @@ local function train()
|
|
|
print("# " .. epoch)
|
|
|
resampling(x, y, train_x, pairwise_func)
|
|
|
for i = 1, settings.inner_epoch do
|
|
|
- print(minibatch_adam(model, criterion, x, y, adam_config))
|
|
|
+ print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config))
|
|
|
model:evaluate()
|
|
|
print("# validation")
|
|
|
- local score = validate(model, criterion, valid_xy)
|
|
|
- if score < best_score then
|
|
|
+ local score = validate(model, eval_metric, valid_xy)
|
|
|
+ if score > best_score then
|
|
|
local test_image = image_loader.load_float(settings.test) -- reload
|
|
|
lrd_count = 0
|
|
|
best_score = score
|