Procházet zdrojové kódy

Fix a bug in ClippedMSECriterion

nagadomi před 9 roky
rodič
revize
9ec1f5159b

+ 1 - 1
lib/ClippedMSECriterion.lua

@@ -8,7 +8,7 @@ function ClippedMSECriterion:__init(min, max)
 end
 function ClippedMSECriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
-   self.diff[torch.lt(self.diff, self.min)]:clamp(self.min, self.max)
+   self.diff:clamp(self.min, self.max)
    self.diff:add(-1, target)
    self.output = self.diff:pow(2):sum() / input:nElement()
    return self.output

+ 1 - 2
lib/ClippedWeightedHuberCriterion.lua

@@ -14,8 +14,7 @@ function ClippedWeightedHuberCriterion:__init(w, gamma, clip)
 end
 function ClippedWeightedHuberCriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
-   self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1]
-   self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2]
+   self.diff:clamp(self.clip[1], self.clip[2])
    for i = 1, input:size(1) do
       self.diff[i]:add(-1, target[i]):cmul(self.weight)
    end