瀏覽代碼

Fix a bug in ClippedMSECriterion:backward

nagadomi 8 年之前
父節點
當前提交
6051dad793
共有 1 個文件被更改,包括 3 次插入1 次删除
  1. 3 1
      lib/ClippedMSECriterion.lua

+ 3 - 1
lib/ClippedMSECriterion.lua

@@ -5,12 +5,14 @@ function ClippedMSECriterion:__init(min, max)
    self.min = min
    self.min = min
    self.max = max
    self.max = max
    self.diff = torch.Tensor()
    self.diff = torch.Tensor()
+   self.diff_pow2 = torch.Tensor()
 end
 end
 function ClippedMSECriterion:updateOutput(input, target)
 function ClippedMSECriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
    self.diff:resizeAs(input):copy(input)
    self.diff:clamp(self.min, self.max)
    self.diff:clamp(self.min, self.max)
    self.diff:add(-1, target)
    self.diff:add(-1, target)
-   self.output = self.diff:pow(2):sum() / input:nElement()
+   self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)
+   self.output = self.diff_pow2:sum() / input:nElement()
    return self.output
    return self.output
 end
 end
 function ClippedMSECriterion:updateGradInput(input, target)
 function ClippedMSECriterion:updateGradInput(input, target)