|
@@ -369,6 +369,31 @@ local function create_criterion(model)
|
|
|
local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
|
|
|
aux.sizeAverage = true
|
|
|
return aux:cuda()
|
|
|
+ elseif settings.loss == "aux_huber" then
|
|
|
+ local args = {}
|
|
|
+ if reconstruct.is_rgb(model) then
|
|
|
+ local offset = reconstruct.offset_size(model)
|
|
|
+ local output_w = settings.crop_size - offset * 2
|
|
|
+ local weight = torch.Tensor(3, output_w * output_w)
|
|
|
+ weight[1]:fill(0.29891 * 3) -- R
|
|
|
+ weight[2]:fill(0.58661 * 3) -- G
|
|
|
+ weight[3]:fill(0.11448 * 3) -- B
|
|
|
+ args = {weight, 0.1, {0.0, 1.0}}
|
|
|
+ else
|
|
|
+ local offset = reconstruct.offset_size(model)
|
|
|
+ local output_w = settings.crop_size - offset * 2
|
|
|
+ local weight = torch.Tensor(1, output_w * output_w)
|
|
|
+ weight[1]:fill(1.0)
|
|
|
+ args = {weight, 0.1, {0.0, 1.0}}
|
|
|
+ end
|
|
|
+ local aux = w2nn.AuxiliaryLossCriterion(w2nn.ClippedWeightedHuberCriterion, args)
|
|
|
+ return aux:cuda()
|
|
|
+ elseif settings.loss == "lbp" then
|
|
|
+ if reconstruct.is_rgb(model) then
|
|
|
+ return w2nn.RandomBinaryCriterion(3, 512):cuda()
|
|
|
+ else
|
|
|
+ return w2nn.RandomBinaryCriterion(1, 512):cuda()
|
|
|
+ end
|
|
|
else
|
|
|
error("unsupported loss .." .. settings.loss)
|
|
|
end
|
|
@@ -506,9 +531,9 @@ local function train()
|
|
|
local criterion = create_criterion(model)
|
|
|
local eval_metric = nil
|
|
|
if settings.loss:find("aux_") ~= nil then
|
|
|
- eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
|
|
|
+ eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion):cuda()
|
|
|
else
|
|
|
- eval_metric = w2nn.ClippedMSECriterion()
|
|
|
+ eval_metric = w2nn.ClippedMSECriterion():cuda()
|
|
|
end
|
|
|
local adam_config = {
|
|
|
xLearningRate = settings.learning_rate,
|