|
@@ -2,8 +2,8 @@ local ClippedMSECriterion, parent = torch.class('w2nn.ClippedMSECriterion','nn.C
|
|
|
|
|
|
function ClippedMSECriterion:__init(min, max)
|
|
function ClippedMSECriterion:__init(min, max)
|
|
parent.__init(self)
|
|
parent.__init(self)
|
|
- self.min = min
|
|
|
|
- self.max = max
|
|
|
|
|
|
+ self.min = min or 0
|
|
|
|
+ self.max = max or 1
|
|
self.diff = torch.Tensor()
|
|
self.diff = torch.Tensor()
|
|
self.diff_pow2 = torch.Tensor()
|
|
self.diff_pow2 = torch.Tensor()
|
|
end
|
|
end
|