-- Weighted mean-squared-error criterion.
--
-- For a batch input[1..B] with matching target[1..B], the loss is
--
--   L = mean over all elements of ( weight .* (input[i] - target[i]) )^2
--
-- where `weight` is multiplied element-wise into every sample of the batch.
local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion')

-- w: weight tensor applied element-wise to each sample. It is cloned, so later
--    changes to the caller's tensor do not affect this criterion.
--    NOTE(review): assumed to match the shape of one sample input[i], since it
--    is combined with each sample via cmul. Confirm against callers.
function WeightedMSECriterion:__init(w)
   parent.__init(self)
   self.weight = w:clone()
   -- Weighted residual weight .* (input - target).
   -- Filled by updateOutput and reused by updateGradInput.
   self.diff = torch.Tensor()
   -- Scratch buffer for the element-wise squared weighted residual.
   self.loss = torch.Tensor()
end

-- Computes the weighted MSE over the batch.
-- Samples are taken along the first dimension of `input` (and `target`).
-- Returns the scalar loss, which is also stored in self.output.
function WeightedMSECriterion:updateOutput(input, target)
   self.diff:resizeAs(input):copy(input)
   for i = 1, input:size(1) do
      self.diff[i]:add(-1, target[i]):cmul(self.weight)
   end
   self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff)
   self.output = self.loss:mean()
   return self.output
end

-- Gradient of the loss with respect to input:
--
--   dL/dinput[i] = 2/N * weight .* weight .* (input[i] - target[i])
--                = 2/N * weight .* diff[i]
--
-- where N = input:nElement().
-- Reads self.diff as cached by updateOutput, so updateOutput must have been
-- called with the same input/target first.
function WeightedMSECriterion:updateGradInput(input, target)
   local norm = 2.0 / input:nElement()
   self.gradInput:resizeAs(input):copy(self.diff):mul(norm)
   -- diff already carries one factor of weight. Differentiating the squared
   -- weighted residual contributes the second factor, which the previous
   -- version omitted (making the gradient inconsistent with the loss).
   for i = 1, input:size(1) do
      self.gradInput[i]:cmul(self.weight)
   end
   return self.gradInput
end