srcnn.lua 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. require 'w2nn'
  2. -- ref: http://arxiv.org/abs/1502.01852
  3. -- ref: http://arxiv.org/abs/1501.00092
  local srcnn = {}

  --- Return the channel count a model produces.
  -- Looks at the module immediately before the last one (in the models built
  -- by this file the last module is an nn.View and the one before it is the
  -- final convolution) and returns the first dimension of its weight tensor.
  -- For the nn/cudnn spatial convolutions used here that first weight
  -- dimension is nOutputPlane.
  -- @param model an nn.Sequential-like container (needs :size() and :get(i))
  -- @return number the first dimension of that module's weight
  function srcnn.channels(model)
    local conv_index = model:size() - 1
    local final_conv = model:get(conv_index)
    local weight = final_conv.weight
    return weight:size(1)
  end
  -- Private builder shared by srcnn.waifu2x_cunn and srcnn.waifu2x_cudnn,
  -- which previously duplicated this layer list line for line.
  -- Stacks seven 3x3 convolutions (stride 1, padding 0) with plane counts
  -- ch -> 32 -> 32 -> 64 -> 64 -> 128 -> 128 -> ch, puts w2nn.LeakyReLU(0.1)
  -- between consecutive convolutions (none after the last one), and ends with
  -- nn.View(-1):setNumInputDims(3), which flattens the last three input dims.
  -- @param conv convolution constructor, called as
  --        conv(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
  -- @param ch number of input planes and of final output planes
  -- @return nn.Sequential
  local function build_waifu2x(conv, ch)
    local planes = {ch, 32, 32, 64, 64, 128, 128, ch}
    local n_conv = #planes - 1
    local model = nn.Sequential()
    for i = 1, n_conv do
      model:add(conv(planes[i], planes[i + 1], 3, 3, 1, 1, 0, 0))
      -- activation only between convolutions; the last conv feeds the View
      if i < n_conv then
        model:add(w2nn.LeakyReLU(0.1))
      end
    end
    model:add(nn.View(-1):setNumInputDims(3))
    -- Debug shape check (disabled):
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end

  --- Build the waifu2x SRCNN using nn.SpatialConvolutionMM layers.
  -- @param ch number of image channels (input and output planes)
  -- @return nn.Sequential
  function srcnn.waifu2x_cunn(ch)
    return build_waifu2x(nn.SpatialConvolutionMM, ch)
  end

  --- Build the waifu2x SRCNN using cudnn.SpatialConvolution layers.
  -- `cudnn` is a global that is not required by this file; it is looked up
  -- only when this function is called.
  -- @param ch number of image channels (input and output planes)
  -- @return nn.Sequential
  function srcnn.waifu2x_cudnn(ch)
    return build_waifu2x(cudnn.SpatialConvolution, ch)
  end
  --- Create a waifu2x SRCNN model for the given backend and color mode.
  -- @param model_name not read by this function (kept for interface compatibility)
  -- @param backend "cunn" (nn.SpatialConvolutionMM) or "cudnn" (cudnn.SpatialConvolution)
  -- @param color "rgb" (3 channels) or "y" (1 channel)
  -- @return nn.Sequential built by srcnn.waifu2x_cunn / srcnn.waifu2x_cudnn
  -- @raise error "unsupported color: <color>" or "unsupported backend: <backend>"
  function srcnn.create(model_name, backend, color)
    local ch
    if color == "rgb" then
      ch = 3
    elseif color == "y" then
      ch = 1
    else
      -- Lua concatenates with `..`; `+` would attempt numeric coercion and raise
      -- an unrelated arithmetic error. tostring() keeps nil arguments readable.
      error("unsupported color: " .. tostring(color))
    end
    if backend == "cunn" then
      return srcnn.waifu2x_cunn(ch)
    elseif backend == "cudnn" then
      return srcnn.waifu2x_cudnn(ch)
    else
      error("unsupported backend: " .. tostring(backend))
    end
  end
  65. return srcnn