|
@@ -390,27 +390,27 @@ local function create_criterion(model)
|
|
return aux:cuda()
|
|
return aux:cuda()
|
|
elseif settings.loss == "lbp" then
|
|
elseif settings.loss == "lbp" then
|
|
if reconstruct.is_rgb(model) then
|
|
if reconstruct.is_rgb(model) then
|
|
- return w2nn.RandomBinaryCriterion(3, 128):cuda()
|
|
|
|
|
|
+ return w2nn.LBPCriterion(3, 128):cuda()
|
|
else
|
|
else
|
|
- return w2nn.RandomBinaryCriterion(1, 128):cuda()
|
|
|
|
|
|
+ return w2nn.LBPCriterion(1, 128):cuda()
|
|
end
|
|
end
|
|
elseif settings.loss == "lbp2" then
|
|
elseif settings.loss == "lbp2" then
|
|
if reconstruct.is_rgb(model) then
|
|
if reconstruct.is_rgb(model) then
|
|
- return w2nn.RandomBinaryCriterion(3, 128, 3, 2):cuda()
|
|
|
|
|
|
+ return w2nn.LBPCriterion(3, 128, 3, 2):cuda()
|
|
else
|
|
else
|
|
- return w2nn.RandomBinaryCriterion(1, 128, 3, 2):cuda()
|
|
|
|
|
|
+ return w2nn.LBPCriterion(1, 128, 3, 2):cuda()
|
|
end
|
|
end
|
|
elseif settings.loss == "aux_lbp" then
|
|
elseif settings.loss == "aux_lbp" then
|
|
if reconstruct.is_rgb(model) then
|
|
if reconstruct.is_rgb(model) then
|
|
- return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128}):cuda()
|
|
|
|
|
|
+ return w2nn.AuxiliaryLossCriterion(w2nn.LBPCriterion, {3, 128}):cuda()
|
|
else
|
|
else
|
|
- return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128}):cuda()
|
|
|
|
|
|
+ return w2nn.AuxiliaryLossCriterion(w2nn.LBPCriterion, {1, 128}):cuda()
|
|
end
|
|
end
|
|
elseif settings.loss == "aux_lbp2" then
|
|
elseif settings.loss == "aux_lbp2" then
|
|
if reconstruct.is_rgb(model) then
|
|
if reconstruct.is_rgb(model) then
|
|
- return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 128, 3, 2}):cuda()
|
|
|
|
|
|
+ return w2nn.AuxiliaryLossCriterion(w2nn.LBPCriterion, {3, 128, 3, 2}):cuda()
|
|
else
|
|
else
|
|
- return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 128, 3, 2}):cuda()
|
|
|
|
|
|
+ return w2nn.AuxiliaryLossCriterion(w2nn.LBPCriterion, {1, 128, 3, 2}):cuda()
|
|
end
|
|
end
|
|
else
|
|
else
|
|
error("unsupported loss .." .. settings.loss)
|
|
error("unsupported loss .." .. settings.loss)
|