srcnn.lua 11 KB


  1. require 'w2nn'
  2. -- ref: http://arxiv.org/abs/1502.01852
  3. -- ref: http://arxiv.org/abs/1501.00092
  4. local srcnn = {}
  --- Re-initialise the layer with a He-style normal distribution
  -- (ref: http://arxiv.org/abs/1502.01852).
  -- The 0.1 term is the negative slope of the nn.LeakyReLU(0.1, true)
  -- activations used by the models built in this file. The fan is computed
  -- from the kernel area and the number of output planes.
  -- @param stdv ignored; the deviation is always recomputed from the layer geometry
  function nn.SpatialConvolutionMM:reset(stdv)
    local sigma = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
    self.weight:normal(0, sigma)
    -- The bias tensor may be absent (e.g. a layer configured with :noBias()),
    -- so only clear it when it exists.
    if self.bias then
      self.bias:zero()
    end
  end
  --- Re-initialise the transposed-convolution layer with the same He-style
  -- normal distribution used for the regular convolutions
  -- (ref: http://arxiv.org/abs/1502.01852).
  -- The 0.1 term is the negative slope of the nn.LeakyReLU(0.1, true)
  -- activations used by the models built in this file.
  -- @param stdv ignored; the deviation is always recomputed from the layer geometry
  function nn.SpatialFullConvolution:reset(stdv)
    local sigma = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
    self.weight:normal(0, sigma)
    -- The bias tensor may be absent (e.g. a layer configured with :noBias()),
    -- so only clear it when it exists.
    if self.bias then
      self.bias:zero()
    end
  end
  -- Install the same initialisation on the cuDNN-backed layers, but only when
  -- the cudnn package (and the respective layer class) is actually loaded.
  if cudnn and cudnn.SpatialConvolution then
    --- He-style normal initialisation shared by both cuDNN layer classes.
    -- The 0.1 term is the negative slope of the nn.LeakyReLU(0.1, true)
    -- activations used by the models built in this file.
    -- @param stdv ignored; the deviation is always recomputed from the layer geometry
    local function leaky_he_reset(self, stdv)
      local sigma = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
      self.weight:normal(0, sigma)
      -- The bias tensor may be absent (e.g. a layer configured with :noBias()).
      if self.bias then
        self.bias:zero()
      end
    end

    cudnn.SpatialConvolution.reset = leaky_he_reset
    -- The outer condition only proves SpatialConvolution exists; guard the
    -- transposed variant separately so a cudnn build without it does not
    -- abort the module load.
    if cudnn.SpatialFullConvolution then
      cudnn.SpatialFullConvolution.reset = leaky_he_reset
    end
  end
  --- Drop the layer's temporary buffers.
  -- When gradient tensors exist they are resized to their canonical shape
  -- (nOutputPlane x nInputPlane*kH*kW for the weights, nOutputPlane for the
  -- bias) and zero-filled before the intermediate fields are cleared.
  -- @return whatever nn.utils.clear returns for this module
  function nn.SpatialConvolutionMM:clearState()
    local grad_w, grad_b = self.gradWeight, self.gradBias
    if grad_w then
      local rows = self.nOutputPlane
      local cols = self.nInputPlane * self.kH * self.kW
      grad_w:resize(rows, cols):zero()
    end
    if grad_b then
      grad_b:resize(self.nOutputPlane):zero()
    end
    return nn.utils.clear(
      self,
      'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput'
    )
  end
  --- Number of image channels the model works on.
  -- Uses the `w2nn_channels` annotation when the model carries one; otherwise
  -- falls back to the first weight dimension of the second-to-last module.
  -- @param model container exposing `get` and `size`
  -- @return channel count
  function srcnn.channels(model)
    local declared = model.w2nn_channels
    if declared ~= nil then
      return declared
    end
    local penultimate = model:get(model:size() - 1)
    return penultimate.weight:size(1)
  end
  --- Report which backend the model's convolutions belong to.
  -- @param model container exposing `findModules`
  -- @return "cudnn" if any cudnn (full) convolution is present, else "cunn"
  function srcnn.backend(model)
    local n_conv = #model:findModules("cudnn.SpatialConvolution")
    local n_full = #model:findModules("cudnn.SpatialFullConvolution")
    if n_conv == 0 and n_full == 0 then
      return "cunn"
    end
    return "cudnn"
  end
  --- Colour mode of the model, derived from its channel count.
  -- @param model model accepted by srcnn.channels
  -- @return "rgb" for three channels, "y" for anything else
  function srcnn.color(model)
    if srcnn.channels(model) == 3 then
      return "rgb"
    end
    return "y"
  end
  --- Architecture name of the model.
  -- Prefers the `w2nn_arch_name` annotation. Models without it are identified
  -- by counting their convolution layers (nn.SpatialConvolutionMM first, then
  -- cudnn.SpatialConvolution if none were found).
  -- @param model container exposing `findModules`
  -- @return the annotated name, or "vgg_7" / "vgg_12" by layer count
  -- Raises "unsupported model" for any other layer count.
  function srcnn.name(model)
    local declared = model.w2nn_arch_name
    if declared ~= nil then
      return declared
    end

    local convs = model:findModules("nn.SpatialConvolutionMM")
    if #convs == 0 then
      convs = model:findModules("cudnn.SpatialConvolution")
    end

    local depth = #convs
    if depth == 7 then
      return "vgg_7"
    end
    if depth == 12 then
      return "vgg_12"
    end
    error("unsupported model")
  end
  --- Border offset of the model.
  -- Returns the `w2nn_offset` annotation when present. For un-annotated
  -- "vgg_" models the value is the floored sum of (kW - 1) / 2 over all
  -- convolution layers.
  -- @param model container exposing `findModules`
  -- @return offset
  -- Raises "unsupported model" when the name does not contain "vgg_".
  function srcnn.offset_size(model)
    local declared = model.w2nn_offset
    if declared ~= nil then
      return declared
    end

    local arch = srcnn.name(model)
    if not arch:match("vgg_") then
      error("unsupported model")
    end

    local convs = model:findModules("nn.SpatialConvolutionMM")
    if #convs == 0 then
      convs = model:findModules("cudnn.SpatialConvolution")
    end

    local total = 0
    for _, layer in ipairs(convs) do
      total = total + (layer.kW - 1) / 2
    end
    return math.floor(total)
  end
  --- Upscaling factor of the model.
  -- Returns the `w2nn_scale_factor` annotation when present; otherwise the
  -- factor is looked up by architecture name and defaults to 1.
  -- @param model model accepted by srcnn.name
  -- @return 2 for "upconv_7", 4 for "upconv_8_4x", 1 for everything else
  function srcnn.scale_factor(model)
    local declared = model.w2nn_scale_factor
    if declared ~= nil then
      return declared
    end
    local factor_by_arch = { upconv_7 = 2, upconv_8_4x = 4 }
    return factor_by_arch[srcnn.name(model)] or 1
  end
  --- Build a convolution layer for the requested backend.
  -- @param backend "cunn" (nn.SpatialConvolutionMM) or "cudnn" (cudnn.SpatialConvolution)
  -- @param nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH forwarded
  --        unchanged to the layer constructor
  -- @return the constructed layer
  -- Raises "unsupported backend:<value>" for any other backend.
  local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
      return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
      return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    end
    -- tostring keeps the message intact when backend is nil or not a string;
    -- plain concatenation would raise an unrelated "attempt to concatenate" error.
    error("unsupported backend:" .. tostring(backend))
  end
  --- Build a full (transposed) convolution layer for the requested backend.
  -- @param backend "cunn" (nn.SpatialFullConvolution) or "cudnn" (cudnn.SpatialFullConvolution)
  -- @param nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH forwarded
  --        unchanged to the layer constructor
  -- @return the constructed layer
  -- Raises "unsupported backend:<value>" for any other backend.
  local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
      return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
      return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    end
    -- tostring keeps the message intact when backend is nil or not a string;
    -- plain concatenation would raise an unrelated "attempt to concatenate" error.
    error("unsupported backend:" .. tostring(backend))
  end
  --- VGG-style network with 7 convolution layers.
  -- Six unpadded 3x3 convolutions, each followed by an in-place
  -- LeakyReLU(0.1), then a final 3x3 convolution back to `ch` planes and a
  -- flattening View.
  -- @param backend backend name forwarded to the SpatialConvolution helper
  -- @param ch number of input/output planes
  -- @return nn.Sequential annotated with the w2nn_* metadata fields
  function srcnn.vgg_7(backend, ch)
    local model = nn.Sequential()

    -- {input planes, output planes} of each hidden convolution, in order.
    local hidden_plan = {
      { ch, 32 }, { 32, 32 },
      { 32, 64 }, { 64, 64 },
      { 64, 128 }, { 128, 128 },
    }
    for _, planes in ipairs(hidden_plan) do
      model:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))

    -- Seven unpadded 3x3 convolutions trim one pixel per side each.
    model.w2nn_arch_name = "vgg_7"
    model.w2nn_offset = 7
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  --- VGG-style network with 12 convolution layers.
  -- Eleven unpadded 3x3 convolutions, each followed by an in-place
  -- LeakyReLU(0.1), then a final 3x3 convolution back to `ch` planes and a
  -- flattening View.
  -- @param backend backend name forwarded to the SpatialConvolution helper
  -- @param ch number of input/output planes
  -- @return nn.Sequential annotated with the w2nn_* metadata fields
  function srcnn.vgg_12(backend, ch)
    local model = nn.Sequential()

    -- {input planes, output planes} of each hidden convolution, in order.
    local hidden_plan = {
      { ch, 32 }, { 32, 32 }, { 32, 64 },
      { 64, 64 }, { 64, 64 }, { 64, 64 },
      { 64, 64 }, { 64, 64 }, { 64, 64 },
      { 64, 128 }, { 128, 128 },
    }
    for _, planes in ipairs(hidden_plan) do
      model:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))

    -- Twelve unpadded 3x3 convolutions trim one pixel per side each.
    model.w2nn_arch_name = "vgg_12"
    model.w2nn_offset = 12
    model.w2nn_scale_factor = 1
    model.w2nn_resize = false
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  --- Dilated-convolution network with 7 convolution layers.
  -- Layers 3-5 are nn.SpatialDilatedConvolution (dilation 2, 2 and 4) and are
  -- always taken from `nn`, whatever `backend` is; the remaining layers come
  -- from the SpatialConvolution helper.
  -- @param backend backend name forwarded to the SpatialConvolution helper
  -- @param ch number of input/output planes
  -- @return nn.Sequential annotated with the w2nn_* metadata fields
  function srcnn.dilated_7(backend, ch)
    local model = nn.Sequential()

    -- Append a layer followed by the in-place LeakyReLU(0.1) activation.
    local function add_activated(layer)
      model:add(layer)
      model:add(nn.LeakyReLU(0.1, true))
    end

    add_activated(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    add_activated(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    add_activated(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
    add_activated(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
    add_activated(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
    add_activated(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "dilated_7"
    model.w2nn_offset = 12
    model.w2nn_scale_factor = 1
    model.w2nn_resize = false
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  --- Up-convolution network (2x): six unpadded 3x3 convolutions with in-place
  -- LeakyReLU(0.1), followed by one 4x4 / stride-2 / pad-1 full convolution
  -- that maps back to `ch` planes.
  -- @param backend backend name forwarded to the layer helpers
  -- @param ch number of input/output planes
  -- @return nn.Sequential annotated with the w2nn_* metadata fields
  function srcnn.upconv_7(backend, ch)
    local model = nn.Sequential()

    -- {input planes, output planes} of each 3x3 convolution, in order.
    local conv_plan = {
      { ch, 32 }, { 32, 32 },
      { 32, 64 }, { 64, 64 },
      { 64, 128 }, { 128, 128 },
    }
    for _, planes in ipairs(conv_plan) do
      model:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 1, 1))

    model.w2nn_arch_name = "upconv_7"
    model.w2nn_offset = 12
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  --- Up-convolution network (4x): a 4x4 / stride-2 / pad-1 full convolution,
  -- six unpadded 3x3 convolutions, and a second full convolution of the same
  -- geometry. Every layer except the last is followed by an in-place
  -- LeakyReLU(0.1).
  -- @param backend backend name forwarded to the layer helpers
  -- @param ch number of input/output planes
  -- @return nn.Sequential annotated with the w2nn_* metadata fields
  function srcnn.upconv_8_4x(backend, ch)
    local model = nn.Sequential()

    model:add(SpatialFullConvolution(backend, ch, 32, 4, 4, 2, 2, 1, 1))
    model:add(nn.LeakyReLU(0.1, true))

    -- {input planes, output planes} of each 3x3 convolution, in order.
    local conv_plan = {
      { 32, 32 }, { 32, 32 }, { 32, 64 },
      { 64, 64 }, { 64, 64 }, { 64, 64 },
    }
    for _, planes in ipairs(conv_plan) do
      model:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end

    -- The output plane count follows `ch` (it used to be hard-coded to 3), so
    -- a single-channel model really emits one plane, matching w2nn_channels.
    model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 1, 1))

    model.w2nn_arch_name = "upconv_8_4x"
    model.w2nn_offset = 12
    model.w2nn_scale_factor = 4
    model.w2nn_resize = true
    model.w2nn_channels = ch

    -- Debug aid:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  --- Build a model by architecture name.
  -- @param model_name key of a builder in `srcnn` (default "vgg_7")
  -- @param backend backend name passed to the builder (default "cunn")
  -- @param color "rgb" (3 channels) or "y" (1 channel); default "rgb"
  -- @return the model returned by the selected builder
  -- Raises "unsupported color: ..." or "unsupported model_name: ..." for
  -- unknown arguments, and fails an assertion when the model's offset is not
  -- a whole multiple of its scale factor.
  function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
    color = color or "rgb"

    local ch
    if color == "rgb" then
      ch = 3
    elseif color == "y" then
      ch = 1
    else
      error("unsupported color: " .. color)
    end

    local builder = srcnn[model_name]
    if not builder then
      error("unsupported model_name: " .. model_name)
    end

    local model = builder(backend, ch)
    -- Divisibility must be tested with the modulo operator: with float
    -- division, (offset / scale) * scale == offset holds even when offset is
    -- not a multiple of scale (e.g. 6 / 4 * 4 == 6), so that form never fires.
    assert(model.w2nn_offset % model.w2nn_scale_factor == 0,
           "w2nn_offset must be a multiple of w2nn_scale_factor")
    return model
  end
  287. --local model = srcnn.upconv_8_4x("cunn", 3):cuda()
  288. --print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
  289. return srcnn