-- srcnn.lua

  require 'w2nn'
  -- ref: http://arxiv.org/abs/1502.01852 (He et al., "Delving Deep into Rectifiers")
  --      -- source of the weight initialization below.
  -- ref: http://arxiv.org/abs/1501.00092 (Dong et al., "Image Super-Resolution
  --      Using Deep Convolutional Networks", SRCNN).
  local srcnn = {}

  -- Shared replacement for the convolution modules' reset():
  -- draws weight ~ N(0, stdv) with
  --   stdv = sqrt(2 / ((1 + slope^2) * kW * kH * nOutputPlane))
  -- and zeroes the bias.
  -- `slope` defaults to 0.1, matching the w2nn.LeakyReLU(0.1) layers that
  -- srcnn.vgg_7 / srcnn.vgg_12 place after each hidden convolution.
  -- The fan uses nOutputPlane, i.e. the fan-out variant of He initialization.
  local function he_reset(self, slope)
    slope = slope or 0.1
    -- keep the original left-to-right evaluation order so stdv is bit-identical
    local stdv = math.sqrt(2 / ((1.0 + slope * slope) * self.kW * self.kH * self.nOutputPlane))
    self.weight:normal(0, stdv)
    -- the bias tensor may be absent (e.g. after :noBias()); only zero it if present
    if self.bias then
      self.bias:zero()
    end
  end

  -- The `stdv` argument of nn's reset() is accepted for interface compatibility
  -- but ignored, as in the original override.
  function nn.SpatialConvolutionMM:reset(stdv)
    he_reset(self)
  end
  if cudnn and cudnn.SpatialConvolution then
    function cudnn.SpatialConvolution:reset(stdv)
      he_reset(self)
    end
  end
  -- Override of clearState for nn.SpatialConvolutionMM.
  -- 1. Any existing gradient accumulators are resized and zero-filled:
  --    gradWeight to (nOutputPlane, nInputPlane * kH * kW), gradBias to (nOutputPlane).
  -- 2. The listed buffers are released through nn.utils.clear.
  -- Returns the value returned by nn.utils.clear.
  function nn.SpatialConvolutionMM:clearState()
    local grad_w, grad_b = self.gradWeight, self.gradBias
    if grad_w then
      grad_w:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
      grad_w:zero()
    end
    if grad_b then
      grad_b:resize(self.nOutputPlane)
      grad_b:zero()
    end
    return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
  end
  -- Channel count produced by `model`: the first dimension of the weight of the
  -- second-to-last module. In models built by this file, that module is the final
  -- convolution, which sits just before the trailing nn.View.
  function srcnn.channels(model)
    local final_conv = model:get(model:size() - 1)
    local weight = final_conv.weight
    return weight:size(1)
  end
  -- Backend name of `model`: "cudnn" when it contains at least one
  -- cudnn.SpatialConvolution module, "cunn" otherwise.
  function srcnn.backend(model)
    local cudnn_convs = model:findModules("cudnn.SpatialConvolution")
    return (#cudnn_convs > 0) and "cudnn" or "cunn"
  end
  -- Color mode of `model`: "rgb" for a 3-channel model, "y" for any other
  -- channel count.
  function srcnn.color(model)
    if srcnn.channels(model) ~= 3 then
      return "y"
    end
    return "rgb"
  end
  -- Architecture name of `model`, inferred from its number of convolutions:
  -- 7 -> "vgg_7", 12 -> "vgg_12", any other count -> nil.
  -- nn.SpatialConvolutionMM modules are counted. If there are none,
  -- cudnn.SpatialConvolution modules are counted instead.
  function srcnn.name(model)
    local conv = model:findModules("nn.SpatialConvolutionMM")
    if #conv == 0 then
      conv = model:findModules("cudnn.SpatialConvolution")
    end
    local names_by_depth = { [7] = "vgg_7", [12] = "vgg_12" }
    return names_by_depth[#conv]
  end
  -- Pixels lost at each border by the model's convolutions: the sum of
  -- (kW - 1) / 2 over every convolution, floored.
  -- Padding is not taken into account; the models built in this file use pad 0.
  -- nn.SpatialConvolutionMM modules are used if present, otherwise
  -- cudnn.SpatialConvolution modules.
  function srcnn.offset_size(model)
    local convs = model:findModules("nn.SpatialConvolutionMM")
    if #convs == 0 then
      convs = model:findModules("cudnn.SpatialConvolution")
    end
    local total = 0
    for _, layer in ipairs(convs) do
      total = total + (layer.kW - 1) / 2
    end
    return math.floor(total)
  end
  -- Create a 2-D convolution module for the named backend:
  -- "cunn" -> nn.SpatialConvolutionMM, "cudnn" -> cudnn.SpatialConvolution.
  -- All remaining arguments are forwarded unchanged to the constructor.
  -- Any other backend name raises an error.
  local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    local ctor
    if backend == "cunn" then
      ctor = nn.SpatialConvolutionMM
    elseif backend == "cudnn" then
      ctor = cudnn.SpatialConvolution
    else
      error("unsupported backend:" .. backend)
    end
    return ctor(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
  end
  -- Build a VGG-style stack of 3x3, stride-1, unpadded convolutions.
  --   backend: "cunn" or "cudnn", forwarded to SpatialConvolution
  --   ch:      number of input (and final output) planes
  --   widths:  output plane counts of the hidden convolutions, in order
  -- Layout:
  --   - every hidden convolution is followed by w2nn.LeakyReLU(0.1);
  --   - a final convolution maps back to `ch` planes;
  --   - nn.View(-1):setNumInputDims(3) flattens the last three dims of the output.
  -- Modules are constructed in the same order as the original hand-written lists,
  -- so weight-initialization RNG draws happen in the same order.
  local function build_vgg(backend, ch, widths)
    local model = nn.Sequential()
    local n_in = ch
    for _, n_out in ipairs(widths) do
      model:add(SpatialConvolution(backend, n_in, n_out, 3, 3, 1, 1, 0, 0))
      model:add(w2nn.LeakyReLU(0.1))
      n_in = n_out
    end
    model:add(SpatialConvolution(backend, n_in, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    return model
  end

  -- VGG style net (7 convolution layers)
  function srcnn.vgg_7(backend, ch)
    return build_vgg(backend, ch, { 32, 32, 64, 64, 128, 128 })
  end

  -- VGG style net (12 convolution layers)
  function srcnn.vgg_12(backend, ch)
    return build_vgg(backend, ch, { 32, 32, 64, 64, 64, 64, 64, 64, 64, 128, 128 })
  end
  -- Construct a fresh model.
  --   model_name: "vgg_7" (default) or "vgg_12"
  --   backend:    "cunn" (default) or "cudnn"; checked when layers are built
  --   color:      "rgb" (default, 3 channels) or "y" (1 channel)
  -- Raises an error for an unknown color (checked first) or model name.
  function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
    color = color or "rgb"

    local channels_by_color = { rgb = 3, y = 1 }
    local ch = channels_by_color[color]
    if ch == nil then
      error("unsupported color: " .. color)
    end

    if model_name ~= "vgg_7" and model_name ~= "vgg_12" then
      error("unsupported model_name: " .. model_name)
    end
    return srcnn[model_name](backend, ch)
  end
  return srcnn