-- PSNRCriterion.lua (636 B)
  -- w2nn.PSNRCriterion: an evaluation-only nn.Criterion whose forward pass
  -- returns the peak signal-to-noise ratio (in dB) between `input` and `target`.
  --
  -- The input is clamped to [0, 1] and the peak value in the PSNR formula is
  -- 1.0. Both tensors are therefore presumably intensities in [0, 1]
  -- (NOTE(review): confirm against callers).
  --
  -- The criterion has no gradient: calling backward / updateGradInput raises.
  local PSNRCriterion, parent = torch.class('w2nn.PSNRCriterion','nn.Criterion')

  -- Default lower bound on the MSE: the square of a uniform error of 0.1/255.
  -- It keeps the logarithm finite when input == target, which caps the
  -- reported PSNR at roughly 68.13 dB.
  local DEFAULT_MIN_MSE = (0.1 / 255) ^ 2

  -- Natural log of 10, for the change-of-base log10. math.log10 is deprecated
  -- in Lua 5.2 and missing from stock Lua 5.3, so it is avoided for portability.
  local LN10 = math.log(10)

  -- Constructor.
  -- @param min_mse optional number; lower bound applied to the MSE before
  --        taking the logarithm. Defaults to (0.1/255)^2, the original value.
  function PSNRCriterion:__init(min_mse)
    parent.__init(self)
    self.min_mse = min_mse or DEFAULT_MIN_MSE
    -- Scratch buffers reused across calls to avoid reallocating every forward.
    self.image = torch.Tensor()
    self.diff = torch.Tensor()
  end

  -- Computes PSNR = 10 * log10(1.0 / max(mse, min_mse)), where mse is the mean
  -- squared difference between clamp(input, 0, 1) and target.
  -- @param input  tensor of predictions; it is copied, not modified.
  -- @param target tensor with the same number of elements as `input`.
  -- @return the PSNR in dB. It is also stored in self.output.
  function PSNRCriterion:updateOutput(input, target)
    -- Clamp a copy so the caller's `input` stays untouched.
    self.image:resizeAs(input):copy(input):clamp(0.0, 1.0)
    self.diff:resizeAs(self.image):copy(self.image):add(-1, target)
    -- Fall back to the default when min_mse is absent, e.g. on an instance
    -- deserialized from before this field existed.
    local floor = self.min_mse or DEFAULT_MIN_MSE
    local mse = math.max(self.diff:pow(2):mean(), floor)
    self.output = 10 * math.log(1.0 / mse) / LN10
    return self.output
  end

  -- PSNR is used only as a metric here; there is no backward pass.
  function PSNRCriterion:updateGradInput(input, target)
    error("PSNRCriterion does not support backward")
  end