|
@@ -262,6 +262,7 @@ local function make_validation_set(x, n, patches)
|
|
|
return data
|
|
|
end
|
|
|
local function validate(model, criterion, eval_metric, data, batch_size)
|
|
|
+ local psnr = 0
|
|
|
local loss = 0
|
|
|
local mse = 0
|
|
|
local loss_count = 0
|
|
@@ -286,8 +287,10 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|
|
inputs:copy(inputs_tmp)
|
|
|
targets:copy(targets_tmp)
|
|
|
local z = model:forward(inputs)
|
|
|
+ local batch_mse = eval_metric:forward(z, targets)
|
|
|
loss = loss + criterion:forward(z, targets)
|
|
|
- mse = mse + eval_metric:forward(z, targets)
|
|
|
+ mse = mse + batch_mse
|
|
|
+ psnr = psnr + (10 * math.log10(1 / batch_mse))
|
|
|
loss_count = loss_count + 1
|
|
|
if loss_count % 10 == 0 then
|
|
|
xlua.progress(t, #data)
|
|
@@ -295,7 +298,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|
|
end
|
|
|
end
|
|
|
xlua.progress(#data, #data)
|
|
|
- return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
|
|
|
+ return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = psnr / loss_count}
|
|
|
end
|
|
|
|
|
|
local function create_criterion(model)
|
|
@@ -540,9 +543,15 @@ local function train()
|
|
|
if settings.plot then
|
|
|
plot(hist_train, hist_valid)
|
|
|
end
|
|
|
- if score.loss < best_score then
|
|
|
+ local score_for_update
|
|
|
+ if settings.update_criterion == "mse" then
|
|
|
+ score_for_update = score.MSE
|
|
|
+ else
|
|
|
+ score_for_update = score.loss
|
|
|
+ end
|
|
|
+ if score_for_update < best_score then
|
|
|
local test_image = image_loader.load_float(settings.test) -- reload
|
|
|
- best_score = score.loss
|
|
|
+ best_score = score_for_update
|
|
|
print("* model has updated")
|
|
|
if settings.save_history then
|
|
|
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
|
@@ -591,7 +600,7 @@ local function train()
|
|
|
end
|
|
|
end
|
|
|
end
|
|
|
- print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
|
|
|
+ print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", best: " .. best_score)
|
|
|
collectgarbage()
|
|
|
end
|
|
|
end
|