srcnn.lua 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. require './LeakyReLU'
  -- Override of nn.SpatialConvolutionMM's parameter initialization.
  -- Weights are drawn from a zero-mean normal distribution with standard
  -- deviation `stdv`; biases (when present) are set to zero.
  --   stdv : optional standard deviation. When nil, the He-style scale
  --          sqrt(2 / n) is used, with n = kW * kH * nOutputPlane
  --          (kernel area times output planes, i.e. the fan-out variant).
  -- NOTE(review): presumably nn calls reset() without an argument when a
  -- module is constructed, so the default above is what new layers get — confirm.
  function nn.SpatialConvolutionMM:reset(stdv)
     -- previously an explicit stdv was silently overwritten; honor it now
     stdv = stdv or math.sqrt(2 / (self.kW * self.kH * self.nOutputPlane))
     self.weight:normal(0, stdv)
     -- a module built without a bias tensor has self.bias == nil; guard so
     -- resetting such a module does not raise "attempt to index a nil value"
     if self.bias then
        self.bias:fill(0)
     end
  end
  -- Module table holding the SRCNN-style network builders (returned below).
  local srcnn = {}

  -- Channel schedule shared by both builders: plane counts before/after each
  -- convolution (planes[i] -> planes[i + 1] for convolution i).
  local PLANES = {1, 32, 32, 64, 64, 128, 128, 1}

  -- Build a plain stack of stride-1, zero-padding convolutions.
  --   kernels : square kernel size of each convolution, in order
  --   planes  : plane counts; must have exactly #kernels + 1 entries
  -- Every convolution except the last is followed by nn.LeakyReLU(0.1), and
  -- the output is flattened per sample with nn.View(-1):setNumInputDims(3).
  -- Returns the model and the per-side border offset: the sum of
  -- (k - 1) / 2 over all kernels. It is derived here rather than hard-coded,
  -- so it cannot drift from the layer list.
  local function build(kernels, planes)
     assert(#planes == #kernels + 1,
            "srcnn: planes must have exactly one more entry than kernels")
     local model = nn.Sequential()
     local offset = 0
     for i, k in ipairs(kernels) do
        model:add(nn.SpatialConvolutionMM(planes[i], planes[i + 1], k, k, 1, 1, 0, 0))
        if i < #kernels then
           model:add(nn.LeakyReLU(0.1))
        end
        -- an unpadded k x k, stride-1 convolution shrinks each side by (k - 1) / 2
        offset = offset + math.floor((k - 1) / 2)
     end
     model:add(nn.View(-1):setNumInputDims(3))
     return model, offset
  end

  -- Seven 3x3 convolutions over the PLANES schedule.
  -- Returns the model and its per-side border offset (7).
  -- Quick shape check used during development:
  --   model:cuda()
  --   print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
  function srcnn.waifu2x()
     return build({3, 3, 3, 3, 3, 3, 3}, PLANES)
  end

  -- NOTE: the current 4x model performs worse than applying 2x twice (2x * 2).
  -- Seven convolutions with kernels 9, 3, 5, 3, 5, 3, 5 over the PLANES schedule.
  -- Returns the model and its per-side border offset (13).
  function srcnn.waifu4x()
     return build({9, 3, 5, 3, 5, 3, 5}, PLANES)
  end

  return srcnn