-- srcnn.lua
  1. require 'w2nn'
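  -- ref: http://arxiv.org/abs/1502.01852 (He et al., "Delving Deep into Rectifiers": weight initialization for rectifier networks)
  -- ref: http://arxiv.org/abs/1501.00092 (Dong et al., "Image Super-Resolution Using Deep Convolutional Networks": SRCNN)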
  4. local srcnn = {}
  --- Re-initialise this layer's parameters (replaces the stock `reset`).
  -- Weights are sampled from a normal distribution with mean 0 and
  --   sigma = sqrt(2 / ((1 + a^2) * kW * kH * nOutputPlane)),  a = 0.1
  -- and the bias is set to zero. The constant 0.1 is the same value the
  -- model builders in this file pass to w2nn.LeakyReLU.
  -- @param stdv ignored; any caller-supplied value is discarded in favour of
  --             the deviation computed here.
  function nn.SpatialConvolutionMM:reset(stdv)
    local slope = 0.1
    -- Keep the left-to-right multiplication order so the result is
    -- numerically identical to the one-line formula.
    local denom = (1.0 + slope * slope) * self.kW * self.kH * self.nOutputPlane
    local sigma = math.sqrt(2 / denom)
    self.weight:normal(0, sigma)
    self.bias:zero()
  end
  -- Install the same initialiser on the cuDNN convolution, but only when the
  -- cudnn package (and its SpatialConvolution class) has been loaded.
  if cudnn and cudnn.SpatialConvolution then
    --- Re-initialise this layer's parameters.
    -- Weights ~ N(0, sigma) with
    --   sigma = sqrt(2 / ((1 + a^2) * kW * kH * nOutputPlane)),  a = 0.1;
    -- the bias is zeroed.
    -- @param stdv ignored; always replaced by the deviation computed here.
    function cudnn.SpatialConvolution:reset(stdv)
      local slope = 0.1
      -- Multiplication order is kept left-to-right on purpose (bit-identical result).
      local denom = (1.0 + slope * slope) * self.kW * self.kH * self.nOutputPlane
      local sigma = math.sqrt(2 / denom)
      self.weight:normal(0, sigma)
      self.bias:zero()
    end
  end
  --- Drop intermediate buffers and reset the gradient accumulators.
  -- When present, gradWeight is resized to
  -- nOutputPlane x (nInputPlane * kH * kW) and gradBias to nOutputPlane,
  -- and both are zero-filled. The listed working fields are then handed to
  -- nn.utils.clear.
  -- @return whatever nn.utils.clear returns
  function nn.SpatialConvolutionMM:clearState()
    local gradW, gradB = self.gradWeight, self.gradBias
    if gradW then
      local cols = self.nInputPlane * self.kH * self.kW
      gradW:resize(self.nOutputPlane, cols):zero()
    end
    if gradB then
      gradB:resize(self.nOutputPlane):zero()
    end
    return nn.utils.clear(
      self,
      'finput', 'fgradInput',
      '_input', '_gradOutput',
      'output', 'gradInput'
    )
  end
  --- Report the size of the first weight dimension of the second-to-last module.
  -- For models built by srcnn.waifu2x_cunn / srcnn.waifu2x_cudnn in this file,
  -- the second-to-last module is the final convolution (the last one is nn.View).
  -- NOTE(review): presumably this equals the model's output channel count —
  -- confirm against the weight layout of the convolution class in use.
  -- @param model a container exposing :size() and :get(i)
  -- @return weight:size(1) of module number (size - 1)
  function srcnn.channels(model)
    local lastConv = model:get(model:size() - 1)
    return lastConv.weight:size(1)
  end
  --- Build the 7-layer convolutional model using nn.SpatialConvolutionMM.
  -- Layout: six 3x3 convolutions (stride 1, no padding) with output widths
  -- 32, 32, 64, 64, 128, 128, each followed by w2nn.LeakyReLU(0.1); then a
  -- final 3x3 convolution back to `ch` planes and an nn.View(-1) configured
  -- with setNumInputDims(3).
  -- @param ch number of input (and output) planes
  -- @return the assembled nn.Sequential
  function srcnn.waifu2x_cunn(ch)
    local net = nn.Sequential()
    local hidden = {32, 32, 64, 64, 128, 128}
    local nIn = ch
    for _, nOut in ipairs(hidden) do
      net:add(nn.SpatialConvolutionMM(nIn, nOut, 3, 3, 1, 1, 0, 0))
      net:add(w2nn.LeakyReLU(0.1))
      nIn = nOut
    end
    -- Output layer: project back to the input plane count, then apply the view.
    net:add(nn.SpatialConvolutionMM(nIn, ch, 3, 3, 1, 1, 0, 0))
    net:add(nn.View(-1):setNumInputDims(3))
    -- Debug aid (disabled):
    --net:cuda()
    --print(net:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return net
  end
  --- Build the 7-layer convolutional model using cudnn.SpatialConvolution.
  -- Same layout as the cunn variant: six 3x3 convolutions (stride 1, no
  -- padding) with output widths 32, 32, 64, 64, 128, 128, each followed by
  -- w2nn.LeakyReLU(0.1); then a final 3x3 convolution back to `ch` planes and
  -- an nn.View(-1) configured with setNumInputDims(3).
  -- Requires the global `cudnn` package to be loaded.
  -- @param ch number of input (and output) planes
  -- @return the assembled nn.Sequential
  function srcnn.waifu2x_cudnn(ch)
    local net = nn.Sequential()
    local hidden = {32, 32, 64, 64, 128, 128}
    local nIn = ch
    for _, nOut in ipairs(hidden) do
      net:add(cudnn.SpatialConvolution(nIn, nOut, 3, 3, 1, 1, 0, 0))
      net:add(w2nn.LeakyReLU(0.1))
      nIn = nOut
    end
    -- Output layer: project back to the input plane count, then apply the view.
    net:add(cudnn.SpatialConvolution(nIn, ch, 3, 3, 1, 1, 0, 0))
    net:add(nn.View(-1):setNumInputDims(3))
    -- Debug aid (disabled):
    --net:cuda()
    --print(net:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return net
  end
  --- Create a model for the requested backend and colour mode.
  -- @param model_name currently unused; kept so existing callers keep working
  -- @param backend "cunn" (nn.SpatialConvolutionMM) or "cudnn" (cudnn.SpatialConvolution)
  -- @param color "rgb" (3 planes) or "y" (1 plane)
  -- @return the nn.Sequential built by the matching srcnn.waifu2x_* builder
  -- Raises an error naming the offending value when `color` or `backend`
  -- is not one of the supported strings.
  function srcnn.create(model_name, backend, color)
    local ch
    if color == "rgb" then
      ch = 3
    elseif color == "y" then
      ch = 1
    else
      -- `..` is Lua's concatenation operator; `+` would attempt arithmetic on
      -- the strings and raise an unrelated error. tostring() keeps the
      -- message well-formed for nil or non-string arguments.
      error("unsupported color: " .. tostring(color))
    end

    if backend == "cunn" then
      return srcnn.waifu2x_cunn(ch)
    elseif backend == "cudnn" then
      return srcnn.waifu2x_cudnn(ch)
    else
      error("unsupported backend: " .. tostring(backend))
    end
  end
  86. return srcnn