LBPCriterion.lua 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. -- Randomly Generated Local Binary Pattern Loss
  2. local LBPCriterion, parent = torch.class('w2nn.LBPCriterion','nn.Criterion')
  -- Builds the filter stack used by the criterion: an nn.Sequential holding
  -- `layers` w2nn.RandomBinaryConvolution modules with k x k kernels.
  --
  --   ch     : input planes fed to the first convolution
  --   n      : second constructor argument of every convolution, and the
  --            first argument of every convolution after the first
  --   k      : kernel width and height
  --   layers : number of convolutions appended to the container
  --
  -- In the first convolution, weight[c] for c = 1..ch is zeroed and then given
  -- a single 1 at [c][center][center], so those filters copy channel c
  -- through ("channel identity").
  -- NOTE(review): this presumably requires n >= ch (weight[c] must exist for
  -- every c up to ch) -- confirm against w2nn.RandomBinaryConvolution.
  local function create_filters(ch, n, k, layers)
    local net = nn.Sequential()
    for depth = 1, layers do
      local first = (depth == 1)
      -- Only the first layer sees the raw `ch` planes; later layers take `n`.
      local planes_in = n
      if first then
        planes_in = ch
      end
      local conv = w2nn.RandomBinaryConvolution(planes_in, n, k, k)
      if first then
        -- channel identity
        local center = math.floor(k / 2) + 1
        for c = 1, ch do
          local kernel = conv.weight[c]
          kernel:fill(0)
          kernel[c][center][center] = 1
        end
      end
      net:add(conv)
      -- Non-linearity between layers is intentionally left disabled:
      --if layers > 1 and i ~= layers then
      --  model:add(nn.Sigmoid(true))
      --end
    end
    return net
  end
  -- Constructor.
  --
  --   ch     : channel count of the images; used for the first filter layer
  --            and for un-flattening 2D inputs in updateOutput/updateGradInput
  --   n      : filter count handed to create_filters (default 128)
  --   k      : kernel size handed to create_filters (default 3)
  --   layers : number of stacked filter layers (default 1)
  --
  -- gamma is fixed at 0.1; updateOutput and updateGradInput use it as the
  -- threshold that separates the quadratic and linear parts of the Huber loss.
  function LBPCriterion:__init(ch, n, k, layers)
    parent.__init(self)

    -- Hyper-parameters.
    self.ch = ch
    self.n = n or 128
    self.k = k or 3
    self.layers = layers or 1
    self.gamma = 0.1

    -- Two filter stacks: filter1 is applied to the input, filter2 (a clone of
    -- filter1) to the target.
    self.filter1 = create_filters(self.ch, self.n, self.k, self.layers)
    self.filter2 = self.filter1:clone()

    -- Scratch tensors, resized on demand by updateOutput.
    local buffers = {
      "diff", "diff_abs",
      "square_loss_buff", "linear_loss_buff",
      "input", "target",
    }
    for _, field in ipairs(buffers) do
      self[field] = torch.Tensor()
    end
  end
  -- Forward pass: Huber loss between the filter responses of `input` and
  -- `target`.
  --
  -- A 2D argument is reshaped to (size(1), ch, s, s) with
  -- s = sqrt(size(2) / ch), i.e. it is treated as flattened square images.
  -- Both tensors are copied into internal buffers and clamped to [0, 1] before
  -- filtering; the caller's tensors are not modified by the clamp.
  --
  -- With d = filter1(input) - filter2(target), each element contributes
  --   0.5 * d^2                       when |d| <  gamma
  --   gamma * (|d| - 0.5 * gamma)     when |d| >= gamma
  -- and the sum is divided by the number of elements of the filter response.
  -- The result is stored in self.output and returned. self.diff and
  -- self.diff_abs are kept for updateGradInput.
  function LBPCriterion:updateOutput(input, target)
    local ch = self.ch
    -- Un-flatten (batch, ch * s * s) into (batch, ch, s, s); pass through otherwise.
    local function unflatten(t)
      if t:dim() ~= 2 then
        return t
      end
      local side = math.sqrt(t:size(2) / ch)
      return t:reshape(t:size(1), ch, side, side)
    end
    input = unflatten(input)
    target = unflatten(target)

    self.input:resizeAs(input):copy(input):clamp(0, 1)
    self.target:resizeAs(target):copy(target):clamp(0, 1)

    local feat_input = self.filter1:forward(self.input)
    local feat_target = self.filter2:forward(self.target)

    -- huber loss
    local gamma = self.gamma
    self.diff:resizeAs(feat_input):copy(feat_input):add(-1, feat_target)
    self.diff_abs:resizeAs(self.diff):copy(self.diff):abs()

    local inliers = self.diff[torch.lt(self.diff_abs, gamma)]
    local outliers = self.diff[torch.ge(self.diff_abs, gamma)]

    local quad = self.square_loss_buff
    quad:resizeAs(inliers):copy(inliers):pow(2.0):mul(0.5)
    local square_loss = quad:sum()

    local lin = self.linear_loss_buff
    lin:resizeAs(outliers):copy(outliers):abs():add(-0.5 * gamma):mul(gamma)
    local linear_loss = lin:sum()

    --self.outlier_rate = linear_targets:nElement() / input:nElement()
    self.output = (square_loss + linear_loss) / feat_input:nElement()
    return self.output
  end
  -- Backward pass. Relies on self.diff, self.diff_abs and self.input as left
  -- behind by the preceding updateOutput call; `target` is not read here.
  --
  -- The gradient w.r.t. the filter response is built in self.gradInput:
  --   diff * scale                     where |diff| <  gamma
  --   sign(diff) * gamma * scale       where |diff| >= gamma
  -- with scale = self.n / self.input:nElement(). It is then pushed back
  -- through filter1 and that result is returned (self.gradInput itself keeps
  -- the filter-response gradient).
  -- NOTE(review): scale differs from the 1 / nElement(filter response) factor
  -- used for the loss value in updateOutput -- presumably deliberate gradient
  -- scaling; confirm before changing.
  --
  -- A 2D `input` is un-flattened to (size(1), ch, s, s) for the filter call
  -- and the returned gradient is flattened back to 2D.
  function LBPCriterion:updateGradInput(input, target)
    local was_flat = (input:dim() == 2)
    if was_flat then
      local side = math.sqrt(input:size(2) / self.ch)
      input = input:reshape(input:size(1), self.ch, side, side)
    end

    local scale = self.n / self.input:nElement()
    local grad = self.gradInput
    grad:resizeAs(self.diff):copy(self.diff):mul(scale)

    -- Elements in the linear region get a constant-magnitude gradient.
    local mask = torch.ge(self.diff_abs, self.gamma)
    grad[mask] = torch.sign(self.diff[mask]) * self.gamma * scale

    local grad_input = self.filter1:updateGradInput(input, grad)
    if not was_flat then
      return grad_input
    end
    local per_sample = grad_input:size(2) * grad_input:size(3) * grad_input:size(4)
    return grad_input:reshape(grad_input:size(1), per_sample)
  end