srcnn.lua 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. require './LeakyReLU'
  -- Replaces nn.SpatialConvolutionMM:reset on the shared class table, so it applies to
  -- every SpatialConvolutionMM instance (presumably invoked by the nn constructor for
  -- layers created after this file is loaded -- confirm against the installed nn version).
  --
  -- Weights are drawn from a zero-mean normal distribution with standard deviation
  -- sqrt(2 / (kW * kH * nOutputPlane)), a He/MSRA-style initialisation using the layer's
  -- fan-out. The bias, when the layer has one, is set to zero.
  --
  -- stdv: accepted only for signature compatibility with nn's reset(stdv); any value
  --       passed in is ignored and the deviation is always recomputed from the layer shape.
  function nn.SpatialConvolutionMM:reset(stdv)
     stdv = math.sqrt(2 / (self.kW * self.kH * self.nOutputPlane))
     self.weight:normal(0, stdv)
     -- A convolution configured without bias (e.g. via :noBias() in newer nn versions)
     -- has self.bias == nil; indexing it unconditionally would raise an error.
     if self.bias then
        self.bias:fill(0)
     end
  end
  -- Module table; the model constructors are attached to it and it is returned at
  -- the end of the file.
  local srcnn = {}

  -- Builds the 2x waifu2x network: seven unpadded, stride-1, 3x3 SpatialConvolutionMM
  -- layers with plane widths ch -> 32 -> 32 -> 64 -> 64 -> 128 -> 128 -> ch. Every
  -- convolution except the last is followed by LeakyReLU(0.1), and the result is passed
  -- through nn.View(-1):setNumInputDims(3), which flattens the trailing three dimensions.
  --
  -- color: "rgb" selects 3 planes, "y" selects 1 plane. Any other value raises
  --        "unknown color: <value>" ("unknown color: nil" when color is nil or false).
  --
  -- Returns the nn.Sequential model and 7. Since each unpadded 3x3 convolution trims
  -- one pixel from every border, 7 equals the per-side shrinkage of the output
  -- (presumably used by callers as the crop offset -- confirm).
  function srcnn.waifu2x(color)
     local model = nn.Sequential()

     local planes_for = { rgb = 3, y = 1 }
     local ch = planes_for[color]
     if not ch then
        if not color then
           error("unknown color: nil")
        end
        error("unknown color: " .. color)
     end

     -- Consecutive entries are the (input, output) plane counts of each convolution.
     local widths = { ch, 32, 32, 64, 64, 128, 128, ch }
     local last = #widths - 1
     for i = 1, last do
        model:add(nn.SpatialConvolutionMM(widths[i], widths[i + 1], 3, 3, 1, 1, 0, 0))
        if i < last then
           model:add(nn.LeakyReLU(0.1))
        end
     end

     model:add(nn.View(-1):setNumInputDims(3))
     return model, 7
  end
  -- NOTE (original author): the current 4x model performs worse than applying the
  -- 2x model twice.
  --
  -- Builds the 4x network: seven unpadded, stride-1 SpatialConvolutionMM layers with
  -- kernel sizes 9, 3, 5, 3, 5, 3, 5 and plane widths ch -> 32 -> 32 -> 64 -> 64 ->
  -- 128 -> 128 -> ch. Every convolution except the last is followed by LeakyReLU(0.1),
  -- and the result is passed through nn.View(-1):setNumInputDims(3), which flattens
  -- the trailing three dimensions.
  --
  -- color: "rgb" selects 3 planes, "y" selects 1 plane. Any other value (including
  --        nil) raises "unknown color: <value>".
  --
  -- Returns the nn.Sequential model and 13, the per-side shrinkage of the output: the
  -- sum of (k - 1) / 2 over the unpadded kernels (4 + 1 + 2 + 1 + 2 + 1 + 2).
  function srcnn.waifu4x(color)
     local model = nn.Sequential()

     local planes_for = { rgb = 3, y = 1 }
     local ch = planes_for[color]
     if not ch then
        -- tostring() keeps the intended message for nil/boolean values instead of
        -- failing with "attempt to concatenate a nil value".
        error("unknown color: " .. tostring(color))
     end

     -- { input planes, output planes, square kernel size } for each convolution.
     local layers = {
        { ch, 32, 9 },
        { 32, 32, 3 },
        { 32, 64, 5 },
        { 64, 64, 3 },
        { 64, 128, 5 },
        { 128, 128, 3 },
        { 128, ch, 5 },
     }
     for i, spec in ipairs(layers) do
        local nin, nout, k = spec[1], spec[2], spec[3]
        model:add(nn.SpatialConvolutionMM(nin, nout, k, k, 1, 1, 0, 0))
        if i < #layers then
           model:add(nn.LeakyReLU(0.1))
        end
     end

     model:add(nn.View(-1):setNumInputDims(3))
     return model, 13
  end

  return srcnn