-- srcnn.lua (10 KB)


  -- Model constructors and metadata helpers, collected in the `srcnn` table.
  require("w2nn")
  -- ref: http://arxiv.org/abs/1502.01852
  -- ref: http://arxiv.org/abs/1501.00092

  -- Module table; every public function below is attached to it and it is
  -- returned at the end of the file.
  local srcnn = {}
  -- Replacement initializer for nn.SpatialConvolutionMM.
  --
  -- Weights are drawn from a normal distribution with mean 0 and
  --   stdv = sqrt(2 / ((1 + 0.1^2) * kW * kH * nOutputPlane))
  -- The 0.1 presumably corresponds to the LeakyReLU(0.1) slope used by the
  -- models in this file (see the arXiv references at the top) -- TODO confirm.
  --
  -- @param stdv  Ignored; the value passed by the caller is overwritten.
  function nn.SpatialConvolutionMM:reset(stdv)
    stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
    self.weight:normal(0, stdv)
    -- Layers without a bias term have no `bias` tensor; skip them instead of
    -- failing on a nil index.
    if self.bias then
      self.bias:zero()
    end
  end
  -- Replacement initializer for nn.SpatialFullConvolution.
  --
  -- Weights are drawn from a normal distribution with mean 0 and
  --   stdv = sqrt(2 / ((1 + 0.1^2) * kW * kH * nOutputPlane))
  -- The 0.1 presumably corresponds to the LeakyReLU(0.1) slope used by the
  -- models in this file -- TODO confirm.
  --
  -- @param stdv  Ignored; the value passed by the caller is overwritten.
  function nn.SpatialFullConvolution:reset(stdv)
    stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
    self.weight:normal(0, stdv)
    -- Layers without a bias term have no `bias` tensor; skip them instead of
    -- failing on a nil index.
    if self.bias then
      self.bias:zero()
    end
  end
  -- Install the same initializer on the cudnn convolution classes, but only
  -- when the cudnn package (and the class being patched) is actually loaded.
  if cudnn and cudnn.SpatialConvolution then
    -- Shared initializer: normal(0, stdv) weights and zeroed bias, with
    --   stdv = sqrt(2 / ((1 + 0.1^2) * kW * kH * nOutputPlane))
    -- The incoming `stdv` argument is ignored and overwritten.
    local function fan_out_reset(self, stdv)
      stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
      self.weight:normal(0, stdv)
      -- Bias-less layers carry no `bias` tensor.
      if self.bias then
        self.bias:zero()
      end
    end

    cudnn.SpatialConvolution.reset = fan_out_reset
    -- Guard separately: a cudnn build lacking SpatialFullConvolution would
    -- otherwise raise here while this file is being loaded.
    if cudnn.SpatialFullConvolution then
      cudnn.SpatialFullConvolution.reset = fan_out_reset
    end
  end
  -- clearState override for nn.SpatialConvolutionMM.
  --
  -- Existing gradient buffers are resized to
  --   gradWeight: nOutputPlane x (nInputPlane * kH * kW)
  --   gradBias:   nOutputPlane
  -- and filled with zeros; the listed intermediate fields are then handed to
  -- nn.utils.clear, whose result is returned.
  function nn.SpatialConvolutionMM:clearState()
    local grad_weight = self.gradWeight
    if grad_weight then
      grad_weight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW)
      grad_weight:zero()
    end

    local grad_bias = self.gradBias
    if grad_bias then
      grad_bias:resize(self.nOutputPlane)
      grad_bias:zero()
    end

    return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
  end
  -- Number of channels of `model`.
  --
  -- Returns the `w2nn_channels` field when it is set. Otherwise falls back to
  -- the first dimension of the weight tensor of the second-to-last module
  -- (presumably the final convolution of models that end in nn.View --
  -- verify for other layouts).
  function srcnn.channels(model)
    local declared = model.w2nn_channels
    if declared ~= nil then
      return declared
    end
    local penultimate = model:get(model:size() - 1)
    return penultimate.weight:size(1)
  end
  -- Backend name of `model`: "cudnn" when it contains at least one
  -- cudnn.SpatialConvolution module, "cunn" otherwise.
  function srcnn.backend(model)
    local cudnn_convs = model:findModules("cudnn.SpatialConvolution")
    if #cudnn_convs == 0 then
      return "cunn"
    end
    return "cudnn"
  end
  -- Color mode of `model`: "rgb" when srcnn.channels reports 3, "y" for any
  -- other channel count.
  function srcnn.color(model)
    local is_rgb = (srcnn.channels(model) == 3)
    return is_rgb and "rgb" or "y"
  end
  -- Architecture name of `model`.
  --
  -- Returns the `w2nn_arch_name` field when present. Otherwise the name is
  -- inferred from the number of convolution modules (nn.SpatialConvolutionMM
  -- first, cudnn.SpatialConvolution if none are found): 7 -> "vgg_7",
  -- 12 -> "vgg_12". Any other count raises "unsupported model name".
  function srcnn.name(model)
    local declared = model.w2nn_arch_name
    if declared then
      return declared
    end

    local convs = model:findModules("nn.SpatialConvolutionMM")
    if #convs == 0 then
      convs = model:findModules("cudnn.SpatialConvolution")
    end

    local name_by_count = { [7] = "vgg_7", [12] = "vgg_12" }
    local inferred = name_by_count[#convs]
    if inferred == nil then
      error("unsupported model name")
    end
    return inferred
  end
  -- Offset of `model`.
  --
  -- Returns the `w2nn_offset` field when it is set. Otherwise, for models
  -- whose name contains "vgg_", sums (kW - 1) / 2 over every convolution
  -- module and returns the floor of the total. Other names raise
  -- "unsupported model name".
  function srcnn.offset_size(model)
    if model.w2nn_offset ~= nil then
      return model.w2nn_offset
    end

    local arch = srcnn.name(model)
    if not arch:match("vgg_") then
      error("unsupported model name")
    end

    local convs = model:findModules("nn.SpatialConvolutionMM")
    if #convs == 0 then
      convs = model:findModules("cudnn.SpatialConvolution")
    end

    local total = 0
    for _, layer in ipairs(convs) do
      total = total + (layer.kW - 1) / 2
    end
    return math.floor(total)
  end
  -- Whether `model` is flagged as a resizing model.
  --
  -- Returns the `w2nn_resize` field when it is set; otherwise true exactly
  -- when the architecture name contains "upconv".
  function srcnn.has_resize(model)
    local flag = model.w2nn_resize
    if flag ~= nil then
      return flag
    end
    return srcnn.name(model):match("upconv") ~= nil
  end
  -- Build a convolution layer for the requested backend.
  --
  -- @param backend  "cunn" -> nn.SpatialConvolutionMM,
  --                 "cudnn" -> cudnn.SpatialConvolution
  -- The remaining arguments are forwarded unchanged to the constructor.
  -- Raises "unsupported backend:<value>" for any other backend value.
  local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
      return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    end
    if backend == "cudnn" then
      return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    end
    -- tostring() keeps the message intact when backend is nil or another
    -- non-string value (plain concatenation would raise a different error).
    error("unsupported backend:" .. tostring(backend))
  end
  -- Build a full-convolution layer for the requested backend.
  --
  -- @param backend  "cunn" -> nn.SpatialFullConvolution,
  --                 "cudnn" -> cudnn.SpatialFullConvolution
  -- The remaining arguments are forwarded unchanged to the constructor.
  -- Raises "unsupported backend:<value>" for any other backend value.
  local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
      return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    end
    if backend == "cudnn" then
      return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    end
    -- tostring() keeps the message intact when backend is nil or another
    -- non-string value (plain concatenation would raise a different error).
    error("unsupported backend:" .. tostring(backend))
  end
  -- VGG style net (7 layers)
  --
  -- Seven 3x3, stride-1, unpadded convolutions. The first six are each
  -- followed by w2nn.LeakyReLU(0.1); the last maps back to `ch` planes and is
  -- followed by nn.View(-1):setNumInputDims(3).
  --
  -- @param backend  "cunn" or "cudnn"
  -- @param ch       number of input and output planes
  -- @return nn.Sequential tagged with the w2nn_* metadata fields
  function srcnn.vgg_7(backend, ch)
    local hidden_planes = { 32, 32, 64, 64, 128, 128 }
    local model = nn.Sequential()

    local n_in = ch
    for _, n_out in ipairs(hidden_planes) do
      model:add(SpatialConvolution(backend, n_in, n_out, 3, 3, 1, 1, 0, 0))
      model:add(w2nn.LeakyReLU(0.1))
      n_in = n_out
    end
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))

    -- Metadata read back by srcnn.name / offset_size / has_resize / channels.
    model.w2nn_arch_name = "vgg_7"
    model.w2nn_offset = 7
    model.w2nn_resize = false
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- VGG style net (12 layers)
  --
  -- Twelve 3x3, stride-1, unpadded convolutions. The first eleven are each
  -- followed by w2nn.LeakyReLU(0.1); the last maps back to `ch` planes and is
  -- followed by nn.View(-1):setNumInputDims(3).
  --
  -- @param backend  "cunn" or "cudnn"
  -- @param ch       number of input and output planes
  -- @return nn.Sequential tagged with the w2nn_* metadata fields
  function srcnn.vgg_12(backend, ch)
    local hidden_planes = { 32, 32, 64, 64, 64, 64, 64, 64, 64, 128, 128 }
    local model = nn.Sequential()

    local n_in = ch
    for _, n_out in ipairs(hidden_planes) do
      model:add(SpatialConvolution(backend, n_in, n_out, 3, 3, 1, 1, 0, 0))
      model:add(w2nn.LeakyReLU(0.1))
      n_in = n_out
    end
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))

    -- Metadata read back by srcnn.name / offset_size / has_resize / channels.
    model.w2nn_arch_name = "vgg_12"
    model.w2nn_offset = 12
    model.w2nn_resize = false
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- Dilated Convolution (7 layers)
  --
  -- Same layer widths as vgg_7, but layers 3-5 are nn.SpatialDilatedConvolution
  -- modules with dilation 2, 2 and 4. Those three layers are always created
  -- from `nn`, regardless of `backend`.
  --
  -- @param backend  "cunn" or "cudnn" (used for the non-dilated layers)
  -- @param ch       number of input and output planes
  -- @return nn.Sequential tagged with the w2nn_* metadata fields
  function srcnn.dilated_7(backend, ch)
    local model = nn.Sequential()

    -- Append `layer` followed by its LeakyReLU activation.
    local function push_activated(layer)
      model:add(layer)
      model:add(w2nn.LeakyReLU(0.1))
    end

    push_activated(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    push_activated(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    push_activated(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
    push_activated(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
    push_activated(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
    push_activated(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))

    -- Metadata read back by srcnn.name / offset_size / has_resize / channels.
    model.w2nn_arch_name = "dilated_7"
    model.w2nn_offset = 12
    model.w2nn_resize = false
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- Up Convolution
  --
  -- Six 3x3, stride-1, unpadded convolutions, each followed by
  -- w2nn.LeakyReLU(0.1), then a 4x4 full convolution with stride 2 and
  -- padding 1 that maps back to `ch` planes (presumably a 2x upscale --
  -- TODO confirm). The model is flagged with w2nn_resize = true.
  --
  -- @param backend  "cunn" or "cudnn"
  -- @param ch       number of input and output planes
  -- @return nn.Sequential tagged with the w2nn_* metadata fields
  function srcnn.upconv_7(backend, ch)
    local hidden_planes = { 32, 32, 64, 64, 128, 128 }
    local model = nn.Sequential()

    local n_in = ch
    for _, n_out in ipairs(hidden_planes) do
      model:add(SpatialConvolution(backend, n_in, n_out, 3, 3, 1, 1, 0, 0))
      model:add(w2nn.LeakyReLU(0.1))
      n_in = n_out
    end
    model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 1, 1))

    -- Metadata read back by srcnn.name / offset_size / has_resize / channels.
    model.w2nn_arch_name = "upconv_7"
    model.w2nn_offset = 12
    model.w2nn_resize = true
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- Up Convolution, 8 layers, with a full convolution at both ends.
  --
  -- Layout: a 4x4 / stride-2 / pad-1 full convolution, six 3x3 unpadded
  -- convolutions, and a final 4x4 / stride-2 / pad-1 full convolution. Every
  -- layer except the last is followed by w2nn.LeakyReLU(0.1). The "4x" in
  -- the name presumably refers to the two stride-2 full convolutions --
  -- TODO confirm.
  --
  -- @param backend  "cunn" or "cudnn"
  -- @param ch       number of input and output planes
  -- @return nn.Sequential tagged with the w2nn_* metadata fields
  function srcnn.upconv_8_4x(backend, ch)
    local model = nn.Sequential()

    model:add(SpatialFullConvolution(backend, ch, 32, 4, 4, 2, 2, 1, 1))
    model:add(w2nn.LeakyReLU(0.1))

    local hidden_planes = { 32, 32, 64, 64, 64, 64 }
    local n_in = 32
    for _, n_out in ipairs(hidden_planes) do
      model:add(SpatialConvolution(backend, n_in, n_out, 3, 3, 1, 1, 0, 0))
      model:add(w2nn.LeakyReLU(0.1))
      n_in = n_out
    end

    -- Emit `ch` planes (not a fixed 3) so the output agrees with the
    -- w2nn_channels value recorded below; identical for ch == 3.
    model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 1, 1))

    -- Metadata read back by srcnn.name / offset_size / has_resize / channels.
    model.w2nn_arch_name = "upconv_8_4x"
    model.w2nn_offset = 12
    model.w2nn_resize = true
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- Create a model by architecture name.
  --
  -- @param model_name  key of a constructor in `srcnn` (default "vgg_7")
  -- @param backend     passed through to the constructor (default "cunn")
  -- @param color       "rgb" (3 channels) or "y" (1 channel); default "rgb"
  -- @return whatever the selected constructor returns
  -- Raises "unsupported color: ..." or "unsupported model_name: ..." when the
  -- corresponding argument is not recognised.
  --
  -- NOTE(review): the lookup is a plain srcnn[model_name], so the name of any
  -- field of `srcnn` (including the helper functions) is accepted here.
  function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
    color = color or "rgb"

    local ch
    if color == "rgb" then
      ch = 3
    elseif color == "y" then
      ch = 1
    else
      -- tostring() so that a non-string argument still produces this message
      -- rather than a concatenation error.
      error("unsupported color: " .. tostring(color))
    end

    local constructor = srcnn[model_name]
    if not constructor then
      error("unsupported model_name: " .. tostring(model_name))
    end
    return constructor(backend, ch)
  end
  -- Manual smoke test (disabled):
  --local model = srcnn.upconv_8_4x("cunn", 3):cuda()
  --print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())

  -- Export the module table.
  return srcnn