|
@@ -262,6 +262,7 @@ local function make_validation_set(x, n, patches)
|
|
return data
|
|
return data
|
|
end
|
|
end
|
|
local function validate(model, criterion, eval_metric, data, batch_size)
|
|
local function validate(model, criterion, eval_metric, data, batch_size)
|
|
|
|
+ local psnr = 0
|
|
local loss = 0
|
|
local loss = 0
|
|
local mse = 0
|
|
local mse = 0
|
|
local loss_count = 0
|
|
local loss_count = 0
|
|
@@ -286,8 +287,10 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|
inputs:copy(inputs_tmp)
|
|
inputs:copy(inputs_tmp)
|
|
targets:copy(targets_tmp)
|
|
targets:copy(targets_tmp)
|
|
local z = model:forward(inputs)
|
|
local z = model:forward(inputs)
|
|
|
|
+ local batch_mse = eval_metric:forward(z, targets)
|
|
loss = loss + criterion:forward(z, targets)
|
|
loss = loss + criterion:forward(z, targets)
|
|
- mse = mse + eval_metric:forward(z, targets)
|
|
|
|
|
|
+ mse = mse + batch_mse
|
|
|
|
+ psnr = psnr + (10 * math.log10(1 / batch_mse))
|
|
loss_count = loss_count + 1
|
|
loss_count = loss_count + 1
|
|
if loss_count % 10 == 0 then
|
|
if loss_count % 10 == 0 then
|
|
xlua.progress(t, #data)
|
|
xlua.progress(t, #data)
|
|
@@ -295,7 +298,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
xlua.progress(#data, #data)
|
|
xlua.progress(#data, #data)
|
|
- return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
|
|
|
|
|
|
+ return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = psnr / loss_count}
|
|
end
|
|
end
|
|
|
|
|
|
local function create_criterion(model)
|
|
local function create_criterion(model)
|