ShakeShakeTable.lua

-- Shake-Shake style stochastic mixing of two branch outputs.
-- Takes a table {input1, input2} of identically sized tensors and blends them
-- with per-sample random coefficients: alpha in the forward pass, an
-- independently drawn beta in the backward pass.
local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable', 'nn.Module')

function ShakeShakeTable:__init()
   parent.__init(self)
   self.alpha = torch.Tensor()   -- forward mixing coefficients, one per sample
   self.beta = torch.Tensor()    -- backward mixing coefficients, one per sample
   self.first = torch.Tensor()   -- scratch buffer
   self.second = torch.Tensor()  -- scratch buffer
   self.train = true
end

function ShakeShakeTable:updateOutput(input)
   local batch_size = input[1]:size(1)
   if self.train then
      -- draw fresh alpha and beta in [0, 1) for each sample in the batch
      self.alpha:resize(batch_size):uniform()
      self.beta:resize(batch_size):uniform()
      -- second = alpha * input[2], per sample
      self.second:resizeAs(input[1]):copy(input[2])
      for i = 1, batch_size do
         self.second[i]:mul(self.alpha[i])
      end
      -- output = (1 - alpha) * input[1], per sample
      self.output:resizeAs(input[1]):copy(input[1])
      for i = 1, batch_size do
         self.output[i]:mul(1.0 - self.alpha[i])
      end
      -- output = 2 * ((1 - alpha) * input[1] + alpha * input[2]);
      -- the factor of 2 keeps the expected magnitude equal to input[1] + input[2]
      self.output:add(self.second):mul(2)
   else
      -- inference: plain sum of the two branches
      self.output:resizeAs(input[1]):copy(input[1]):add(input[2])
   end
   return self.output
end

function ShakeShakeTable:updateGradInput(input, gradOutput)
   local batch_size = input[1]:size(1)
   -- first = beta * gradOutput, per sample (gradient routed to input[1])
   self.first:resizeAs(gradOutput):copy(gradOutput)
   for i = 1, batch_size do
      self.first[i]:mul(self.beta[i])
   end
   -- second = (1 - beta) * gradOutput, per sample (gradient routed to input[2])
   self.second:resizeAs(gradOutput):copy(gradOutput)
   for i = 1, batch_size do
      self.second[i]:mul(1.0 - self.beta[i])
   end
   -- nn.Module convention: gradients w.r.t. a table input go in self.gradInput
   self.gradInput = {self.first, self.second}
   return self.gradInput
end
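
Usage sketch (not part of the original file): ShakeShakeTable is typically fed by an nn.ConcatTable that runs two parallel branches over the same input, so the module receives the table {branch1(x), branch2(x)}. The residual_branch helper, the channel count, and the `require 'w2nn'` line below are assumptions made for illustration; the identity shortcut of a full Shake-Shake residual block is also omitted for brevity.

require 'nn'
require 'w2nn'  -- assumed to register the w2nn.ShakeShakeTable class

-- hypothetical two-convolution branch, for illustration only
local function residual_branch(ch)
   return nn.Sequential()
      :add(nn.SpatialConvolution(ch, ch, 3, 3, 1, 1, 1, 1))
      :add(nn.ReLU(true))
      :add(nn.SpatialConvolution(ch, ch, 3, 3, 1, 1, 1, 1))
end

local ch = 64
local block = nn.Sequential()
   :add(nn.ConcatTable()
           :add(residual_branch(ch))
           :add(residual_branch(ch)))  -- produces {branch1(x), branch2(x)}
   :add(w2nn.ShakeShakeTable())        -- mixes the two branch outputs

local x = torch.Tensor(4, ch, 8, 8):uniform()

block:training()   -- alpha and beta are resampled on every forward pass
local y = block:forward(x)
local dx = block:backward(x, y:clone():fill(1))

block:evaluate()   -- inference sums the two branches deterministically
local y_eval = block:forward(x)

Because beta is drawn independently of alpha, the backward pass propagates a different random mix than the forward pass, which is the core idea of Shake-Shake regularization.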