WeightedMSECriterion.lua 712 B

123456789101112131415161718192021222324
  -- Weighted mean-squared-error criterion for Torch7 `nn`.
  --
  -- A per-element weight tensor `w` is applied to every slice input[i]
  -- (the loops below treat the first dimension as the sample index), and
  --
  --    loss = mean( (w * (input - target))^2 )
  --
  -- where `*` is elementwise multiplication and the mean runs over every
  -- element of `input`.
  local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion')

  -- Constructor.
  -- w: weight tensor. A private copy is stored, so later mutation of the
  --    caller's tensor does not affect the criterion. It is multiplied
  --    elementwise into each slice input[i], so it presumably has the shape
  --    of a single sample -- confirm against callers.
  function WeightedMSECriterion:__init(w)
     parent.__init(self)
     self.weight = w:clone()
     -- Work buffers, resized in place on each call.
     self.diff = torch.Tensor()   -- w * (input - target); reused by updateGradInput
     self.loss = torch.Tensor()   -- squared weighted residuals
  end

  -- Forward pass.
  -- Returns the scalar weighted MSE (also stored in self.output) and caches
  -- the weighted residual in self.diff for the backward pass.
  function WeightedMSECriterion:updateOutput(input, target)
     self.diff:resizeAs(input):copy(input)
     for i = 1, input:size(1) do
        -- diff[i] = w * (input[i] - target[i])
        self.diff[i]:add(-1, target[i]):cmul(self.weight)
     end
     self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff)
     self.output = self.loss:mean()
     return self.output
  end

  -- Backward pass: gradient of updateOutput's loss with respect to input,
  --
  --    d loss / d input = (2 / N) * w * diff = (2 / N) * w^2 * (input - target)
  --
  -- with N = input:nElement(), the count the mean in updateOutput divides by.
  -- Both the extra factor of w and the 2 / N scale are needed for this to be
  -- the true derivative of the loss.
  -- Reads self.diff written by updateOutput, so it must follow a forward
  -- call on the same input/target (the usual nn forward/backward order).
  -- `target` is not read here; it is kept for the nn.Criterion signature.
  function WeightedMSECriterion:updateGradInput(input, target)
     local norm = 2.0 / input:nElement()
     self.gradInput:resizeAs(input):copy(self.diff)
     for i = 1, input:size(1) do
        -- second factor of w from differentiating (w * (x - t))^2
        self.gradInput[i]:cmul(self.weight)
     end
     self.gradInput:mul(norm)
     return self.gradInput
  end