-- SSIMCriterion.lua

  -- SSIMCriterion: structural similarity (SSIM) index between an input and a
  -- target tensor, following the reference MATLAB implementation
  --   http://www.cns.nyu.edu/~lcv/ssim/ssim_index.m
  -- Registered as the torch class 'w2nn.SSIMCriterion', derived from nn.Criterion.
  local SSIMCriterion, parent =
    torch.class('w2nn.SSIMCriterion', 'nn.Criterion')
  -- Constructor.
  --   ch            number of channels (default 1). Channels are filtered
  --                 independently: the convolution weight is block diagonal.
  --   kernel_size   side length of the square Gaussian window (default 11)
  --   sigma         standard deviation of the Gaussian window (default 1.5)
  --   dynamic_range optional dynamic range L of the values (default 1, i.e.
  --                 values in 0-1). The stabilising constants are
  --                 c1 = (0.01 * L)^2 and c2 = (0.03 * L)^2, as in ssim_index.m.
  function SSIMCriterion:__init(ch, kernel_size, sigma, dynamic_range)
    parent.__init(self)
    ch = ch or 1
    kernel_size = kernel_size or 11
    sigma = sigma or 1.5
    dynamic_range = dynamic_range or 1

    -- Normalised size x size Gaussian window centred on tap floor(size/2)+1.
    -- No amplitude factor is applied: any constant factor cancels when the
    -- kernel is divided by its sum.
    local function gaussian2d(size, s)
      local kernel = torch.Tensor(size, size)
      local u = math.floor(size / 2) + 1
      for x = 1, size do
        for y = 1, size do
          kernel[x][y] = math.exp(-((x - u)^2 + (y - u)^2) / (2 * s^2))
        end
      end
      return kernel:div(kernel:sum())
    end

    self.c1 = (0.01 * dynamic_range)^2
    self.c2 = (0.03 * dynamic_range)^2
    self.ch = ch

    -- 'valid' filtering: stride 1, no padding, no bias. Output plane i only
    -- sees input plane i. The diagonal (i, i) slices hold the window and all
    -- other slices are zero.
    local window = gaussian2d(kernel_size, sigma)
    self.conv = nn.SpatialConvolution(ch, ch, kernel_size, kernel_size, 1, 1, 0, 0):noBias()
    self.conv.weight:zero()
    for i = 1, ch do
      self.conv.weight[i][i]:copy(window)
    end

    -- Work buffers, resized and refilled by updateOutput on every call.
    self.mu1 = torch.Tensor()
    self.mu2 = torch.Tensor()
    self.mu1_sq = torch.Tensor()
    self.mu2_sq = torch.Tensor()
    self.mu1_mu2 = torch.Tensor()
    self.sigma1_sq = torch.Tensor()
    self.sigma2_sq = torch.Tensor()
    self.sigma12 = torch.Tensor()
    self.ssim_map = torch.Tensor()
  end
  -- Mean SSIM index between input and target.
  --
  -- Local statistics are Gaussian-weighted by self.conv ('valid' region only):
  --   mu1, mu2    local means
  --   sigma1_sq,  local variances
  --   sigma2_sq
  --   sigma12     local covariance
  --   ssim = ((2 mu1 mu2 + c1)(2 sigma12 + c2)) /
  --          ((mu1^2 + mu2^2 + c1)(sigma1^2 + sigma2^2 + c2))
  -- After the call, self.mu1 ... self.sigma12 hold the statistics and
  -- self.ssim_map holds the per-position SSIM map. The returned mean is also
  -- stored in self.output.
  -- The value is a similarity (1 for identical inputs), not a loss. With the
  -- default constants (c1 = 0.01^2, c2 = 0.03^2), values are expected in the
  -- dynamic range 0-1.
  function SSIMCriterion:updateOutput(input, target)
    assert(input:nElement() == target:nElement())
    local conv, c1, c2 = self.conv, self.c1, self.c2
    -- Scratch buffers, created lazily; input.new() matches the input's tensor type.
    self._tmp = self._tmp or input.new()
    self._den = self._den or input.new()
    local tmp, den = self._tmp, self._den

    -- forward returns the module's reused output tensor, so each result is
    -- copied out before the next forward call overwrites it.
    local filtered = conv:forward(input)
    local mu1 = self.mu1:resizeAs(filtered):copy(filtered)
    filtered = conv:forward(target)
    local mu2 = self.mu2:resizeAs(filtered):copy(filtered)
    local mu1_sq = self.mu1_sq:resizeAs(mu1):copy(mu1):cmul(mu1)
    local mu2_sq = self.mu2_sq:resizeAs(mu2):copy(mu2):cmul(mu2)
    local mu1_mu2 = self.mu1_mu2:resizeAs(mu1):copy(mu1):cmul(mu2)

    -- (Co)variances: sigma_ab = E[a*b] - E[a]*E[b], with E the Gaussian filter.
    tmp:resizeAs(input):copy(input):cmul(input)
    local sigma1_sq = self.sigma1_sq:resizeAs(mu1):copy(conv:forward(tmp)):add(-1, mu1_sq)
    tmp:resizeAs(target):copy(target):cmul(target)
    local sigma2_sq = self.sigma2_sq:resizeAs(mu1):copy(conv:forward(tmp)):add(-1, mu2_sq)
    tmp:resizeAs(input):copy(input):cmul(target)
    local sigma12 = self.sigma12:resizeAs(mu1):copy(conv:forward(tmp)):add(-1, mu1_mu2)

    -- Build the map in separate buffers so the statistics above stay intact.
    local ssim_map = self.ssim_map:resizeAs(mu1):copy(mu1_mu2):mul(2):add(c1)
    tmp:resizeAs(sigma12):copy(sigma12):mul(2):add(c2)
    ssim_map:cmul(tmp)
    tmp:resizeAs(mu1_sq):copy(mu1_sq):add(mu2_sq):add(c1)
    den:resizeAs(sigma1_sq):copy(sigma1_sq):add(sigma2_sq):add(c2):cmul(tmp)
    ssim_map:cdiv(den)

    self.output = ssim_map:mean()
    return self.output
  end
  -- Backward pass. Not supported: this criterion only reports the SSIM value,
  -- so requesting a gradient raises an error.
  function SSIMCriterion:updateGradInput(input, target)
    error("not implemented", 1)
  end