-- RandomBinaryCriterion.lua (2.8 KB)
  -- Register the class 'w2nn.RandomBinaryCriterion' with 'nn.Criterion' as its
  -- parent. `parent` is kept so the constructor can call the base-class
  -- __init.
  local RandomBinaryCriterion, parent = torch.class(
    'w2nn.RandomBinaryCriterion',
    'nn.Criterion'
  )
  -- Build the convolution module the criterion uses to compare images.
  --
  -- Creates w2nn.RandomBinaryConvolution(ch, n, k, k) and then overwrites its
  -- first `ch` weight slices: slice i is zeroed and a single 1 is written at
  -- [i][center][center], with center = floor(k / 2) + 1.
  -- NOTE(review): presumably the weight is laid out as (output, input, kH, kW),
  -- which would make filter i pass input channel i through unchanged -- confirm
  -- against RandomBinaryConvolution.
  --
  -- ch: number of leading weight slices to overwrite (first constructor arg)
  -- n:  second constructor argument
  -- k:  kernel size, used for both kernel dimensions
  -- Returns the modified convolution module.
  local function create_filters(ch, n, k)
    local conv = w2nn.RandomBinaryConvolution(ch, n, k, k)
    local center = math.floor(k / 2) + 1
    local weight = conv.weight
    local channel = 1
    while channel <= ch do
      local slice = weight[channel]
      slice:fill(0)
      slice[channel][center][center] = 1
      channel = channel + 1
    end
    return conv
  end
  -- Construct the criterion.
  --
  -- ch:    channel count of the tensors being compared (required); also used
  --        to un-flatten 2D inputs in updateOutput/updateGradInput
  -- n:     second argument forwarded to create_filters (default 32)
  -- k:     kernel size forwarded to create_filters (default 3)
  -- gamma: threshold separating the quadratic and linear regions of the Huber
  --        loss (default 0.1). Optional; existing three-argument callers are
  --        unaffected.
  function RandomBinaryCriterion:__init(ch, n, k, gamma)
    if ch == nil then
      -- Fail here with a clear message instead of deep inside the filter setup.
      error('RandomBinaryCriterion: ch (number of channels) is required', 2)
    end
    parent.__init(self)
    self.gamma = gamma or 0.1
    self.n = n or 32
    self.k = k or 3
    self.ch = ch
    -- filter1 is applied to the input and filter2 to the target; filter2 is a
    -- clone of filter1, so both start from the same weights.
    self.filter1 = create_filters(self.ch, self.n, self.k)
    self.filter2 = self.filter1:clone()
    -- Scratch tensors reused across calls; they are resized on demand.
    self.diff = torch.Tensor()              -- filter1(input) - filter2(target)
    self.diff_abs = torch.Tensor()          -- |diff|
    self.square_loss_buff = torch.Tensor()  -- work buffer, quadratic region
    self.linear_loss_buff = torch.Tensor()  -- work buffer, linear region
    self.input = torch.Tensor()             -- input clamped to [0, 1]
    self.target = torch.Tensor()            -- target clamped to [0, 1]
  end
  -- Compute the loss between `input` and `target`.
  --
  -- Both tensors are clamped to [0, 1], passed through filter1 / filter2
  -- respectively, and the element-wise difference of the two responses is
  -- scored with a Huber loss using threshold self.gamma:
  --   |d| <  gamma : 0.5 * d^2
  --   |d| >= gamma : gamma * (|d| - 0.5 * gamma)
  -- The summed loss is divided by the number of response elements.
  --
  -- A 2D argument is reshaped to (size(1), ch, side, side) with
  -- side = sqrt(size(2) / ch) before use.
  --
  -- Side effects: fills self.input, self.target, self.diff and self.diff_abs,
  -- which updateGradInput reads afterwards.
  -- Returns self.output (a number).
  function RandomBinaryCriterion:updateOutput(input, target)
    local channels = self.ch
    local gamma = self.gamma

    -- Un-flatten a 2D tensor; anything else is passed through untouched.
    local function as_4d(t)
      if t:dim() ~= 2 then
        return t
      end
      local side = math.sqrt(t:size(2) / channels)
      return t:reshape(t:size(1), channels, side, side)
    end

    input = as_4d(input)
    target = as_4d(target)

    -- Work on clamped copies so the caller's tensors are not modified.
    local clamped_input = self.input
    clamped_input:resizeAs(input)
    clamped_input:copy(input)
    clamped_input:clamp(0, 1)

    local clamped_target = self.target
    clamped_target:resizeAs(target)
    clamped_target:copy(target)
    clamped_target:clamp(0, 1)

    local response_input = self.filter1:forward(clamped_input)
    local response_target = self.filter2:forward(clamped_target)

    -- diff = response_input - response_target, one leading slice at a time.
    local diff = self.diff
    diff:resizeAs(response_input)
    diff:copy(response_input)
    for row = 1, response_input:size(1) do
      diff[row]:add(-1, response_target[row])
    end

    local magnitude = self.diff_abs
    magnitude:resizeAs(diff)
    magnitude:copy(diff)
    magnitude:abs()

    -- Huber loss: split the differences by magnitude relative to gamma.
    local inliers = diff[torch.lt(magnitude, gamma)]
    local outliers = diff[torch.ge(magnitude, gamma)]

    local square_buff = self.square_loss_buff
    square_buff:resizeAs(inliers)
    square_buff:copy(inliers)
    square_buff:pow(2.0)
    square_buff:mul(0.5)
    local square_loss = square_buff:sum()

    local linear_buff = self.linear_loss_buff
    linear_buff:resizeAs(outliers)
    linear_buff:copy(outliers)
    linear_buff:abs()
    linear_buff:add(-0.5 * gamma)
    linear_buff:mul(gamma)
    local linear_loss = linear_buff:sum()

    --self.outlier_rate = linear_targets:nElement() / input:nElement()
    self.output = (square_loss + linear_loss) / response_input:nElement()
    return self.output
  end
  -- Compute the gradient of the loss with respect to `input`.
  --
  -- Must be called after updateOutput: it reads self.diff and self.input as
  -- left there by the forward pass. `target` is accepted for interface
  -- compatibility with nn.Criterion but is not read here.
  --
  -- A 2D `input` is reshaped to (size(1), ch, side, side) with
  -- side = sqrt(size(2) / ch); in that case the returned gradient is flattened
  -- back to 2D.
  --
  -- Returns the gradient tensor, which is also stored in self.gradInput.
  function RandomBinaryCriterion:updateGradInput(input, target)
    local d2 = false
    if input:dim() == 2 then
      d2 = true
      local k = math.sqrt(input:size(2) / self.ch)
      input = input:reshape(input:size(1), self.ch, k, k)
    end

    -- NOTE(review): updateOutput averages over the number of filter-response
    -- elements, while this scale is n / (number of input elements); the two
    -- are not obviously equal -- confirm the normalization is intentional.
    local norm = self.n / self.input:nElement()

    -- Gradient with respect to the filter responses. It lives in its own
    -- buffer (created lazily, same tensor type as self.diff) so that
    -- self.gradInput can hold the gradient with respect to the input.
    self.grad_response = self.grad_response or self.diff.new()
    local grad_response = self.grad_response

    -- Huber derivative: d inside (-gamma, gamma), sign(d) * gamma outside.
    -- That is exactly clamp(d, -gamma, gamma), so no mask or masked
    -- assignment (and none of their temporaries) is needed.
    grad_response:resizeAs(self.diff):copy(self.diff)
    grad_response:clamp(-self.gamma, self.gamma)
    grad_response:mul(norm)

    -- Back-propagate through the filter applied to the input.
    local grad_input = self.filter1:updateGradInput(input, grad_response)
    if d2 then
      grad_input = grad_input:reshape(grad_input:size(1), grad_input:size(2) * grad_input:size(3) * grad_input:size(4))
    end

    -- Keep the state variable in sync with the returned value.
    self.gradInput = grad_input
    return grad_input
  end