ClippedMSECriterion.lua 1.1 KB

12345678910111213141516171819202122232425262728293031
  -- Mean-squared-error criterion that clamps the input to [min, max] before
  -- comparing it with the target, and that records one loss value per
  -- slice along the first dimension of the input.
  local ClippedMSECriterion, parent = torch.class('w2nn.ClippedMSECriterion','nn.Criterion')

  -- NOTE(review): presumably advertises to callers that `self.instance_loss`
  -- is populated by updateOutput -- confirm against the code that reads it.
  ClippedMSECriterion.has_instance_loss = true

  -- Constructor.
  --   min: lower clamp bound applied to the input (defaults to 0)
  --   max: upper clamp bound applied to the input (defaults to 1)
  function ClippedMSECriterion:__init(min, max)
     parent.__init(self)
     self.min = min or 0
     self.max = max or 1
     -- Scratch buffers, resized on every updateOutput call.
     self.diff = torch.Tensor()       -- clamp(input) - target
     self.diff_pow2 = torch.Tensor()  -- element-wise square of self.diff
     -- instance_loss[i] is the mean squared error of slice i of the last input.
     self.instance_loss = {}
  end

  -- Computes the loss for `input` against `target`.
  --
  -- The input is copied, clamped to [self.min, self.max], and `target` is
  -- subtracted. For every index i along the first dimension, the mean of the
  -- squared differences of that slice is stored in self.instance_loss[i].
  -- Each slice must itself be a tensor (the code calls :sum() on it), so the
  -- input needs at least two dimensions.
  --
  -- Returns the average of the per-slice losses; the same value is stored in
  -- self.output so the value returned here and the `output` field agree.
  function ClippedMSECriterion:updateOutput(input, target)
     local batch_size = input:size(1)

     self.diff:resizeAs(input):copy(input)
     self.diff:clamp(self.min, self.max)
     self.diff:add(-1, target)
     self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)

     self.instance_loss = {}
     local total = 0
     for i = 1, batch_size do
        local squared = self.diff_pow2[i]
        local instance_loss = squared:sum() / squared:nElement()
        self.instance_loss[i] = instance_loss
        total = total + instance_loss
     end

     self.output = total / batch_size
     return self.output
  end

  -- Computes the gradient buffer from the difference cached by the most
  -- recent updateOutput call: self.diff scaled by 1 / input:nElement().
  -- `target` is not read here; updateOutput must have been called first so
  -- that self.diff is up to date.
  --
  -- NOTE(review): the exact derivative of the mean squared error would carry
  -- a factor of 2 and would be zero where the input was clamped; this code
  -- does neither. Left as-is because changing the scale alters training
  -- behaviour -- confirm whether this is intentional.
  function ClippedMSECriterion:updateGradInput(input, target)
     local norm = 1.0 / input:nElement()
     self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
     return self.gradInput
  end