-- ClippedMSECriterion.lua

  1. local ClippedMSECriterion, parent = torch.class('w2nn.ClippedMSECriterion','nn.Criterion')
  2. function ClippedMSECriterion:__init(min, max)
  3. parent.__init(self)
  4. self.min = min or 0
  5. self.max = max or 1
  6. self.diff = torch.Tensor()
  7. self.diff_pow2 = torch.Tensor()
  8. end
  9. function ClippedMSECriterion:updateOutput(input, target)
  10. self.diff:resizeAs(input):copy(input)
  11. self.diff:clamp(self.min, self.max)
  12. self.diff:add(-1, target)
  13. self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)
  14. self.output = self.diff_pow2:sum() / input:nElement()
  15. return self.output
  16. end
  17. function ClippedMSECriterion:updateGradInput(input, target)
  18. local norm = 1.0 / input:nElement()
  19. self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
  20. return self.gradInput
  21. end