-- Shake-shake regularization over a table of two branch outputs.
-- In training, the branches are mixed with per-sample random weights
-- (alpha in the forward pass, an independent beta in the backward pass);
-- in evaluation, the branches are simply summed.
local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable', 'nn.Module')

function ShakeShakeTable:__init()
   parent.__init(self)
   self.alpha = torch.Tensor()
   self.beta = torch.Tensor()
   self.first = torch.Tensor()
   self.second = torch.Tensor()
   self.train = true
end
function ShakeShakeTable:updateOutput(input)
   local batch_size = input[1]:size(1)
   if self.train then
      -- draw per-sample mixing weights for this pass
      self.alpha:resize(batch_size):uniform()
      self.beta:resize(batch_size):uniform()
      -- second = alpha * input[2], per sample
      self.second:resizeAs(input[1]):copy(input[2])
      for i = 1, batch_size do
         self.second[i]:mul(self.alpha[i])
      end
      -- output = (1 - alpha) * input[1], per sample
      self.output:resizeAs(input[1]):copy(input[1])
      for i = 1, batch_size do
         self.output[i]:mul(1.0 - self.alpha[i])
      end
      -- output = 2 * ((1 - alpha) * input[1] + alpha * input[2])
      self.output:add(self.second):mul(2)
   else
      -- at test time, use the expectation of the training output: input[1] + input[2]
      self.output:resizeAs(input[1]):copy(input[1]):add(input[2])
   end
   return self.output
end
function ShakeShakeTable:updateGradInput(input, gradOutput)
   local batch_size = input[1]:size(1)
   -- backward pass mixes the gradient with an independent beta:
   -- gradInput[1] = beta * gradOutput, gradInput[2] = (1 - beta) * gradOutput
   self.first:resizeAs(gradOutput):copy(gradOutput)
   for i = 1, batch_size do
      self.first[i]:mul(self.beta[i])
   end
   self.second:resizeAs(gradOutput):copy(gradOutput)
   for i = 1, batch_size do
      self.second[i]:mul(1.0 - self.beta[i])
   end
   self.gradInput = {self.first, self.second}
   return self.gradInput
end
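
-- Usage sketch (an illustrative assumption, not part of the original module):
-- ShakeShakeTable expects a table {x1, x2} of two same-sized branch outputs,
-- e.g. produced by an nn.ConcatTable, and mixes them per sample. The
-- 64-channel 3x3 convolution branches below are placeholders.
--
--   require 'nn'
--   local branches = nn.ConcatTable()
--   branches:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 1, 1))
--   branches:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 1, 1))
--   local block = nn.Sequential()
--   block:add(branches)               -- produces {x1, x2}
--   block:add(w2nn.ShakeShakeTable()) -- random per-sample mix in training, x1 + x2 in eval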