|
@@ -198,7 +198,7 @@ local function train()
|
|
return transformer(x, is_validation, n, offset)
|
|
return transformer(x, is_validation, n, offset)
|
|
end
|
|
end
|
|
local criterion = create_criterion(model)
|
|
local criterion = create_criterion(model)
|
|
- local eval_metric = w2nn.PSNRCriterion():cuda()
|
|
|
|
|
|
+ local eval_metric = nn.MSECriterion():cuda()
|
|
local x = torch.load(settings.images)
|
|
local x = torch.load(settings.images)
|
|
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
|
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
|
local adam_config = {
|
|
local adam_config = {
|
|
@@ -212,7 +212,7 @@ local function train()
|
|
elseif settings.color == "rgb" then
|
|
elseif settings.color == "rgb" then
|
|
ch = 3
|
|
ch = 3
|
|
end
|
|
end
|
|
- local best_score = 0.0
|
|
|
|
|
|
+ local best_score = 1000.0
|
|
print("# make validation-set")
|
|
print("# make validation-set")
|
|
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
|
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
|
settings.validation_crops,
|
|
settings.validation_crops,
|
|
@@ -227,7 +227,6 @@ local function train()
|
|
ch, settings.crop_size, settings.crop_size)
|
|
ch, settings.crop_size, settings.crop_size)
|
|
local y = torch.Tensor(settings.patches * #train_x,
|
|
local y = torch.Tensor(settings.patches * #train_x,
|
|
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
|
|
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
|
|
-
|
|
|
|
for epoch = 1, settings.epoch do
|
|
for epoch = 1, settings.epoch do
|
|
model:training()
|
|
model:training()
|
|
print("# " .. epoch)
|
|
print("# " .. epoch)
|
|
@@ -238,12 +237,12 @@ local function train()
|
|
model:evaluate()
|
|
model:evaluate()
|
|
print("# validation")
|
|
print("# validation")
|
|
local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
|
|
local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
|
|
- table.insert(hist_train, train_score.PSNR)
|
|
|
|
|
|
+ table.insert(hist_train, train_score.MSE)
|
|
table.insert(hist_valid, score)
|
|
table.insert(hist_valid, score)
|
|
if settings.plot then
|
|
if settings.plot then
|
|
plot(hist_train, hist_valid)
|
|
plot(hist_train, hist_valid)
|
|
end
|
|
end
|
|
- if score > best_score then
|
|
|
|
|
|
+ if score < best_score then
|
|
local test_image = image_loader.load_float(settings.test) -- reload
|
|
local test_image = image_loader.load_float(settings.test) -- reload
|
|
lrd_count = 0
|
|
lrd_count = 0
|
|
best_score = score
|
|
best_score = score
|
|
@@ -281,7 +280,7 @@ local function train()
|
|
lrd_count = 0
|
|
lrd_count = 0
|
|
end
|
|
end
|
|
end
|
|
end
|
|
- print("current: " .. score .. ", best: " .. best_score)
|
|
|
|
|
|
+ print("PSNR: " .. 10 * math.log10(1 / score) .. ", MSE: " .. score .. ", Best MSE: " .. best_score)
|
|
collectgarbage()
|
|
collectgarbage()
|
|
end
|
|
end
|
|
end
|
|
end
|