srcnn.lua 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. require './LeakyReLU'
  -- Weight initialization after He et al., "Delving Deep into Rectifiers"
  -- (http://arxiv.org/abs/1502.01852), adapted for a leaky slope a = 0.1:
  --   stdv = sqrt(2 / ((1 + a^2) * kW * kH * nOutputPlane))
  -- a = 0.1 matches the nn.LeakyReLU(0.1) layers used by the models in this
  -- file. The fan term uses nOutputPlane (the paper's backward-pass variant).
  --
  -- This replaces the method on the nn.SpatialConvolutionMM class itself, so it
  -- affects every instance whose reset() runs after this file is loaded
  -- (torch's constructor is assumed to call reset(); confirm for your nn build).
  -- The `stdv` argument is accepted for interface compatibility but ignored.
  -- Weights are drawn from N(0, stdv); biases are zeroed.
  function nn.SpatialConvolutionMM:reset(stdv)
    local slope = 0.1
    local gain = 1.0 + slope * slope
    -- keep left-to-right multiplication order so the float result is unchanged
    local denom = gain * self.kW * self.kH * self.nOutputPlane
    stdv = math.sqrt(2 / denom)
    self.weight:normal(0, stdv)
    self.bias:zero()
  end
  -- ref: Dong et al., "Image Super-Resolution Using Deep Convolutional Networks"
  -- (SRCNN), http://arxiv.org/abs/1501.00092
  -- Module table holding the model constructors; it is the value returned at
  -- the end of this file.
  local srcnn = {}
  -- Builds the "very deep" 2x network: seven unpadded 3x3 convolutions
  -- (stride 1) with nn.LeakyReLU(0.1) between them, followed by
  -- nn.View(-1):setNumInputDims(3).
  --
  -- color: "rgb" (3 input/output planes) or "y" (1 plane). Any other value
  --        raises an error "unknown color: <value>" ("nil" for nil/false).
  -- Returns: the nn.Sequential model, and 7. The 7 equals the pixels trimmed
  --          from each side by the seven unpadded 3x3 convolutions
  --          (presumably the crop offset used by callers; confirm).
  function srcnn.waifu2x(color)
    local model = nn.Sequential()
    local planes_by_color = { rgb = 3, y = 1 }
    local ch = planes_by_color[color]
    if not ch then
      error("unknown color: " .. (color or "nil"))
    end

    -- Plane counts at each stage; convolution i maps planes[i] -> planes[i + 1].
    local planes = { ch, 32, 32, 64, 64, 128, 128, ch }
    local last = #planes - 1
    for i = 1, last do
      model:add(nn.SpatialConvolutionMM(planes[i], planes[i + 1], 3, 3, 1, 1, 0, 0))
      if i < last then
        -- no activation after the final (output) convolution
        model:add(nn.LeakyReLU(0.1))
      end
    end
    model:add(nn.View(-1):setNumInputDims(3))
    return model, 7
  end
  -- Experimental 4x network. Note: it currently performs worse than applying
  -- the 2x model twice.
  --
  -- Seven unpadded convolutions (stride 1, kernel sizes 9,3,5,3,5,3,5) with
  -- nn.LeakyReLU(0.1) between them, followed by
  -- nn.View(-1):setNumInputDims(3).
  --
  -- color: "rgb" (3 input/output planes) or "y" (1 plane). Any other value,
  --        including nil, raises "unknown color: <tostring(color)>".
  -- Returns: the nn.Sequential model, and the number of pixels trimmed from
  --          each side by the unpadded convolutions (sum of floor(k/2) over
  --          the kernel sizes = 13).
  function srcnn.waifu4x(color)
    local model = nn.Sequential()
    local ch
    if color == "rgb" then
      ch = 3
    elseif color == "y" then
      ch = 1
    else
      -- tostring() so nil (or any non-string) produces the intended message
      -- instead of an "attempt to concatenate a nil value" error.
      error("unknown color: " .. tostring(color))
    end

    -- { input planes, output planes, kernel size } for each convolution.
    local convs = {
      { ch,   32, 9 },
      { 32,   32, 3 },
      { 32,   64, 5 },
      { 64,   64, 3 },
      { 64,  128, 5 },
      { 128, 128, 3 },
      { 128,  ch, 5 },
    }
    local offset = 0
    for i, conv in ipairs(convs) do
      local n_in, n_out, k = conv[1], conv[2], conv[3]
      model:add(nn.SpatialConvolutionMM(n_in, n_out, k, k, 1, 1, 0, 0))
      if i < #convs then
        -- no activation after the final (output) convolution
        model:add(nn.LeakyReLU(0.1))
      end
      -- an unpadded k x k, stride-1 convolution trims floor(k/2) px per side
      offset = offset + math.floor(k / 2)
    end
    model:add(nn.View(-1):setNumInputDims(3))
    return model, offset
  end
  -- The module's value: the table holding the waifu2x and waifu4x constructors.
  return srcnn