DepthExpand2x.lua 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  -- Re-entrancy guard: if the class has already been registered under
  -- w2nn.DepthExpand2x, return that registration instead of calling
  -- torch.class again for the same name.
  -- NOTE(review): this early return yields w2nn.DepthExpand2x, whereas the last
  -- line of the file returns the local class table from torch.class; confirm
  -- no caller depends on the module's return value.
  local registered = w2nn.DepthExpand2x
  if registered then
     return registered
  end
  -- torch.class returns the new class table (the methods below are defined on
  -- it) and the parent class table (used to reach nn.Module's constructor).
  local DepthExpand2x, parent = torch.class('w2nn.DepthExpand2x', 'nn.Module')
  -- Constructor. Runs the nn.Module constructor on this instance. That
  -- constructor presumably provides the self.output / self.gradInput tensors
  -- that updateOutput and updateGradInput fill in place.
  function DepthExpand2x:__init()
     -- Must be `parent.__init(self)`, not `parent:__init()`. The colon form
     -- passes the parent class table as `self`, so the base constructor would
     -- set its fields on nn.Module's class table. Every instance would then
     -- share one output/gradInput tensor through the metatable chain.
     parent.__init(self)
  end
  -- Forward pass: moves channels into space.
  --   input : (batch_size, depth, height, width), depth divisible by 4
  --   output: (batch_size, depth / 4, height * 2, width * 2)
  -- The input size is cached in self.shape.
  function DepthExpand2x:updateOutput(input)
     local shape = input:size()
     self.shape = shape
     assert(shape:size() == 4, "input must be 4d tensor")
     assert(shape[2] % 4 == 0, "depth must be depth % 4 = 0")
     local n, d, h, w = shape[1], shape[2], shape[3], shape[4]

     -- (n, d, h, w) -> (n, w, h, d)
     local y = input:transpose(2, 4)
     -- Split depth into (2, d / 2) and merge the 2 into the height axis:
     -- (n, w, h * 2, d / 2)
     y = y:reshape(n, w, h * 2, d / 2)
     -- (n, h * 2, w, d / 2)
     y = y:transpose(2, 3)
     -- Split depth again and merge the 2 into the width axis:
     -- (n, h * 2, w * 2, d / 4)
     y = y:reshape(n, h * 2, w * 2, d / 4)
     -- Back to channel-first: (n, d / 4, h * 2, w * 2)
     y = y:transpose(2, 4):transpose(3, 4)

     -- y is a strided view; copying it into the resized self.output makes the
     -- result contiguous.
     self.output:resizeAs(y):copy(y)
     return self.output
  end
  -- Backward pass: applies the inverse of the updateOutput rearrangement.
  --   gradOutput: (batch_size, depth / 4, height * 2, width * 2)
  --   gradInput : (batch_size, depth, height, width), same size as `input`
  -- The target size comes from `input` rather than from the self.shape cached
  -- by updateOutput. gradInput therefore always matches the tensor passed in,
  -- even if updateOutput last ran on a differently sized input.
  function DepthExpand2x:updateGradInput(input, gradOutput)
     local shape = input:size()
     assert(shape:size() == 4, "input must be 4d tensor")
     assert(shape[2] % 4 == 0, "depth must be depth % 4 = 0")
     local n, d, h, w = shape[1], shape[2], shape[3], shape[4]
     -- Catch a gradOutput whose element count happens to match but whose
     -- layout does not; the reshapes below would silently scramble it.
     assert(gradOutput:dim() == 4
               and gradOutput:size(1) == n
               and gradOutput:size(2) == d / 4
               and gradOutput:size(3) == h * 2
               and gradOutput:size(4) == w * 2,
            "gradOutput size does not match the output of updateOutput")

     -- (n, d / 4, h * 2, w * 2) -> (n, h * 2, w * 2, d / 4)
     local g = gradOutput:transpose(2, 4):transpose(2, 3)
     -- Undo the width merge: (n, h * 2, w, d / 2)
     g = g:reshape(n, h * 2, w, d / 2)
     -- (n, w, h * 2, d / 2)
     g = g:transpose(2, 3)
     -- Undo the height merge: (n, w, h, d)
     g = g:reshape(n, w, h, d)
     -- (n, d, h, w)
     g = g:transpose(2, 4)

     self.gradInput:resizeAs(g):copy(g)
     return self.gradInput
  end
  -- Manual visual check (requires the `image` package and a display).
  -- Builds a one-sample input with 4x the channels of image.lena(), filled by
  -- cycling source channels 1, 2, 3, 1, 2, 3, ... It displays the input, then
  -- the forward output. It then displays the result of feeding that output
  -- back through updateGradInput, which inverts the rearrangement and so
  -- should look like the input.
  function DepthExpand2x.test()
     require 'image'
     -- Display channels 1-3 of the first sample as a 3-channel image.
     local function show(x)
        local img = torch.Tensor(3, x:size(3), x:size(4))
        for c = 1, 3 do
           img[c]:copy(x[1][c])
        end
        image.display(img)
     end
     local img = image.lena()
     local channels = img:size(1)
     local x = torch.Tensor(1, channels * 4, img:size(2), img:size(3))
     for i = 0, channels * 4 - 1 do
        -- `local`: the original assigned a global here.
        local src_index = (i % 3) + 1
        x[1][i + 1]:copy(img[src_index])
     end
     show(x)
     local de2x = w2nn.DepthExpand2x()
     -- `local`: the original leaked `out` as a global.
     local out = de2x:forward(x)
     show(out)
     out = de2x:updateGradInput(x, out)
     show(out)
  end
  -- Module result: the class table created by torch.class at the top of the file.
  return DepthExpand2x