|
@@ -298,20 +298,26 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|
|
end
|
|
|
|
|
|
local function create_criterion(model)
|
|
|
- if reconstruct.is_rgb(model) then
|
|
|
- local offset = reconstruct.offset_size(model)
|
|
|
- local output_w = settings.crop_size - offset * 2
|
|
|
- local weight = torch.Tensor(3, output_w * output_w)
|
|
|
- weight[1]:fill(0.29891 * 3) -- R
|
|
|
- weight[2]:fill(0.58661 * 3) -- G
|
|
|
- weight[3]:fill(0.11448 * 3) -- B
|
|
|
- return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
|
|
+ if settings.loss == "huber" then
|
|
|
+ if reconstruct.is_rgb(model) then
|
|
|
+ local offset = reconstruct.offset_size(model)
|
|
|
+ local output_w = settings.crop_size - offset * 2
|
|
|
+ local weight = torch.Tensor(3, output_w * output_w)
|
|
|
+ weight[1]:fill(0.29891 * 3) -- R
|
|
|
+ weight[2]:fill(0.58661 * 3) -- G
|
|
|
+ weight[3]:fill(0.11448 * 3) -- B
|
|
|
+ return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
|
|
+ else
|
|
|
+ local offset = reconstruct.offset_size(model)
|
|
|
+ local output_w = settings.crop_size - offset * 2
|
|
|
+ local weight = torch.Tensor(1, output_w * output_w)
|
|
|
+ weight[1]:fill(1.0)
|
|
|
+ return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
|
|
+ end
|
|
|
+ elseif settings.loss == "l1" then
|
|
|
+ return w2nn.L1Criterion():cuda()
|
|
|
else
|
|
|
- local offset = reconstruct.offset_size(model)
|
|
|
- local output_w = settings.crop_size - offset * 2
|
|
|
- local weight = torch.Tensor(1, output_w * output_w)
|
|
|
- weight[1]:fill(1.0)
|
|
|
- return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
|
|
+ error("unsupported loss .." .. settings.loss)
|
|
|
end
|
|
|
end
|
|
|
|
|
@@ -518,9 +524,9 @@ local function train()
|
|
|
if settings.plot then
|
|
|
plot(hist_train, hist_valid)
|
|
|
end
|
|
|
- if score.MSE < best_score then
|
|
|
+ if score.loss < best_score then
|
|
|
local test_image = image_loader.load_float(settings.test) -- reload
|
|
|
- best_score = score.MSE
|
|
|
+ best_score = score.loss
|
|
|
print("* model has updated")
|
|
|
if settings.save_history then
|
|
|
torch.save(settings.model_file_best, model:clearState(), "ascii")
|
|
@@ -569,7 +575,7 @@ local function train()
|
|
|
end
|
|
|
end
|
|
|
end
|
|
|
- print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
|
|
|
+ print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
|
|
|
collectgarbage()
|
|
|
end
|
|
|
end
|