srcnn.lua 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. require './mynn'
  2. -- ref: http://arxiv.org/abs/1502.01852
  3. -- ref: http://arxiv.org/abs/1501.00092
  4. local srcnn = {}
  5. function srcnn.waifu2x_cunn(ch)
  6. local model = nn.Sequential()
  7. model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
  8. model:add(mynn.LeakyReLU(0.1))
  9. model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
  10. model:add(mynn.LeakyReLU(0.1))
  11. model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
  12. model:add(mynn.LeakyReLU(0.1))
  13. model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
  14. model:add(mynn.LeakyReLU(0.1))
  15. model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
  16. model:add(mynn.LeakyReLU(0.1))
  17. model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
  18. model:add(mynn.LeakyReLU(0.1))
  19. model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
  20. model:add(nn.View(-1):setNumInputDims(3))
  21. --model:cuda()
  22. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  23. return model
  24. end
  25. function srcnn.waifu2x_cudnn(ch)
  26. local model = nn.Sequential()
  27. model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
  28. model:add(mynn.LeakyReLU(0.1))
  29. model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
  30. model:add(mynn.LeakyReLU(0.1))
  31. model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
  32. model:add(mynn.LeakyReLU(0.1))
  33. model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
  34. model:add(mynn.LeakyReLU(0.1))
  35. model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
  36. model:add(mynn.LeakyReLU(0.1))
  37. model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
  38. model:add(mynn.LeakyReLU(0.1))
  39. model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
  40. model:add(nn.View(-1):setNumInputDims(3))
  41. --model:cuda()
  42. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  43. return model
  44. end
  45. function srcnn.create(model_name, backend, color)
  46. local ch = 3
  47. if color == "rgb" then
  48. ch = 3
  49. elseif color == "y" then
  50. ch = 1
  51. else
  52. error("unsupported color: " + color)
  53. end
  54. if backend == "cunn" then
  55. return srcnn.waifu2x_cunn(ch)
  56. elseif backend == "cudnn" then
  57. return srcnn.waifu2x_cudnn(ch)
  58. else
  59. error("unsupported backend: " + backend)
  60. end
  61. end
  62. return srcnn