瀏覽代碼

Fix a bug in ClippedMSECriterion

nagadomi 9 年之前
父節點
當前提交
9ec1f5159b
共有 2 個文件被更改,包括 2 次插入3 次删除
  1. 1 1
      lib/ClippedMSECriterion.lua
  2. 1 2
      lib/ClippedWeightedHuberCriterion.lua

+ 1 - 1
lib/ClippedMSECriterion.lua

@@ -8,7 +8,7 @@ function ClippedMSECriterion:__init(min, max)
 end
 function ClippedMSECriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
-   self.diff[torch.lt(self.diff, self.min)]:clamp(self.min, self.max)
+   self.diff:clamp(self.min, self.max)
    self.diff:add(-1, target)
    self.output = self.diff:pow(2):sum() / input:nElement()
    return self.output

+ 1 - 2
lib/ClippedWeightedHuberCriterion.lua

@@ -14,8 +14,7 @@ function ClippedWeightedHuberCriterion:__init(w, gamma, clip)
 end
 function ClippedWeightedHuberCriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
-   self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1]
-   self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2]
+   self.diff:clamp(self.clip[1], self.clip[2])
    for i = 1, input:size(1) do
       self.diff[i]:add(-1, target[i]):cmul(self.weight)
    end