srcnn.lua 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. require 'w2nn'
  2. -- ref: http://arxiv.org/abs/1502.01852
  3. -- ref: http://arxiv.org/abs/1501.00092
  4. local srcnn = {}
  -- MSRA/He-style initialiser for a convolution module.
  -- ref: http://arxiv.org/abs/1502.01852
  --
  -- mod: a module exposing kW, kH, nInputPlane, nOutputPlane, a `weight`
  --      with a :normal(mean, stdv) method and, optionally, a `bias` with :zero().
  --
  -- The weights are drawn from N(0, stdv) where stdv is derived from the sum of
  -- fan-in and fan-out. The 0.1 in the formula is presumably the negative slope of
  -- the nn.LeakyReLU(0.1) activations used by the models in this file.
  local function msra_filler(mod)
    local fin = mod.kW * mod.kH * mod.nInputPlane
    local fout = mod.kW * mod.kH * mod.nOutputPlane
    -- Declared `local` so the value does not leak into (or clobber) a global `stdv`.
    local stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
    mod.weight:normal(0, stdv)
    -- Modules whose bias was removed (e.g. via :noBias()) have nothing to clear.
    if mod.bias then
      mod.bias:zero()
    end
  end
  -- Near-identity initialiser for a convolution module.
  --
  -- mod: a module exposing kW, kH, nInputPlane, nOutputPlane, a 4-level indexable
  --      `weight` (output plane, input plane, row, column) with :normal(), and a
  --      `bias` with :zero().
  --
  -- All weights start as small noise (stdv 0.01) and the bias is zeroed. The output
  -- planes are then split into one group per input plane, and the centre tap of
  -- each (output, input) pair inside a group is set to
  -- nInputPlane / nOutputPlane.
  -- Requires nInputPlane <= nOutputPlane (asserted).
  local function identity_filler(mod)
    assert(mod.nInputPlane <= mod.nOutputPlane)
    mod.weight:normal(0, 0.01)
    mod.bias:zero()

    local groups = mod.nInputPlane -- fixed
    local value = groups / mod.nOutputPlane
    local ins_per_group = math.floor(mod.nInputPlane / groups)
    local outs_per_group = math.floor(mod.nOutputPlane / groups)
    -- 1-based coordinates of the kernel centre
    local cx = math.floor(mod.kW / 2) + 1
    local cy = math.floor(mod.kH / 2) + 1

    for g = 1, groups do
      local first_out = (g - 1) * outs_per_group + 1
      local first_in = (g - 1) * ins_per_group + 1
      for o = first_out, first_out + outs_per_group - 1 do
        for i = first_in, first_in + ins_per_group - 1 do
          mod.weight[o][i][cy][cx] = value
        end
      end
    end
  end
  -- Override reset() on the nn convolution classes so that they are initialised
  -- with the fillers defined in this file. The `stdv` argument is accepted for
  -- signature compatibility but is ignored.
  nn.SpatialConvolutionMM.reset = function(self, stdv)
    msra_filler(self)
  end

  nn.SpatialFullConvolution.reset = function(self, stdv)
    msra_filler(self)
  end

  -- Dilated convolutions use the identity-like filler instead of the MSRA one.
  nn.SpatialDilatedConvolution.reset = function(self, stdv)
    identity_filler(self)
  end
  -- Install the same reset() overrides on the cudnn classes, but only when the
  -- cudnn package is loaded and exposes SpatialConvolution. The `stdv` argument
  -- is ignored by every override.
  if cudnn and cudnn.SpatialConvolution then
    cudnn.SpatialConvolution.reset = function(self, stdv)
      msra_filler(self)
    end

    cudnn.SpatialFullConvolution.reset = function(self, stdv)
      msra_filler(self)
    end

    -- The dilated class is optional, so it is patched only when it exists.
    if cudnn.SpatialDilatedConvolution then
      cudnn.SpatialDilatedConvolution.reset = function(self, stdv)
        identity_filler(self)
      end
    end
  end
  -- clearState() override for nn.SpatialConvolutionMM.
  --
  -- When present, gradWeight is resized to
  -- (nOutputPlane, nInputPlane * kH * kW) and gradBias to (nOutputPlane); both are
  -- zero-filled. The usual temporary buffers are then released through
  -- nn.utils.clear, whose result is returned.
  function nn.SpatialConvolutionMM:clearState()
    local grad_weight = self.gradWeight
    if grad_weight then
      local flat_kernel = self.nInputPlane * self.kH * self.kW
      grad_weight:resize(self.nOutputPlane, flat_kernel):zero()
    end

    local grad_bias = self.gradBias
    if grad_bias then
      grad_bias:resize(self.nOutputPlane):zero()
    end

    return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
  end
  -- Number of image channels the model works on.
  --
  -- Returns model.w2nn_channels when that field is set. Otherwise it falls back to
  -- the first dimension of the weight of the module at index model:size() - 1
  -- (presumably the last convolution of models saved without metadata).
  function srcnn.channels(model)
    local declared = model.w2nn_channels
    if declared ~= nil then
      return declared
    end
    local second_to_last = model:get(model:size() - 1)
    return second_to_last.weight:size(1)
  end
  -- Detect which backend a model was built with.
  --
  -- Returns "cudnn" if the model contains at least one cudnn.SpatialConvolution
  -- or cudnn.SpatialFullConvolution module, and "cunn" otherwise.
  function srcnn.backend(model)
    local n_conv = #model:findModules("cudnn.SpatialConvolution")
    local n_fullconv = #model:findModules("cudnn.SpatialFullConvolution")
    if n_conv == 0 and n_fullconv == 0 then
      return "cunn"
    end
    return "cudnn"
  end
  -- Colour mode of a model: "rgb" when srcnn.channels(model) is 3, "y" for any
  -- other channel count.
  function srcnn.color(model)
    local is_rgb = (srcnn.channels(model) == 3)
    return is_rgb and "rgb" or "y"
  end
  -- Architecture name of a model.
  --
  -- Returns model.w2nn_arch_name when set. Otherwise the name is guessed from the
  -- number of nn.SpatialConvolutionMM modules (or, if there are none, of
  -- cudnn.SpatialConvolution modules): 7 -> "vgg_7", 12 -> "vgg_12".
  -- Raises "unsupported model" for any other count.
  function srcnn.name(model)
    if model.w2nn_arch_name ~= nil then
      return model.w2nn_arch_name
    end

    local convs = model:findModules("nn.SpatialConvolutionMM")
    if #convs == 0 then
      convs = model:findModules("cudnn.SpatialConvolution")
    end

    local name_by_depth = { [7] = "vgg_7", [12] = "vgg_12" }
    local arch = name_by_depth[#convs]
    if arch == nil then
      error("unsupported model")
    end
    return arch
  end
  -- Border offset of a model.
  --
  -- Returns model.w2nn_offset when set. Otherwise, for architectures whose name
  -- contains "vgg_", it sums (kW - 1) / 2 over every convolution module
  -- (nn.SpatialConvolutionMM, or cudnn.SpatialConvolution when none of the former
  -- exist) and returns the floor of that sum.
  -- Raises "unsupported model" for any other architecture name.
  function srcnn.offset_size(model)
    if model.w2nn_offset ~= nil then
      return model.w2nn_offset
    end

    local arch = srcnn.name(model)
    if not arch:match("vgg_") then
      error("unsupported model")
    end

    local convs = model:findModules("nn.SpatialConvolutionMM")
    if #convs == 0 then
      convs = model:findModules("cudnn.SpatialConvolution")
    end

    local total = 0
    for idx = 1, #convs do
      total = total + (convs[idx].kW - 1) / 2
    end
    return math.floor(total)
  end
  -- Upscaling factor of a model.
  --
  -- Returns model.w2nn_scale_factor when set. Otherwise the factor is looked up
  -- from the architecture name: "upconv_7" -> 2, "upconv_8_4x" -> 4, and 1 for
  -- every other name.
  function srcnn.scale_factor(model)
    if model.w2nn_scale_factor ~= nil then
      return model.w2nn_scale_factor
    end
    local factor_by_name = { upconv_7 = 2, upconv_8_4x = 4 }
    local arch = srcnn.name(model)
    return factor_by_name[arch] or 1
  end
  -- Backend-aware convolution factory.
  --
  -- backend: "cunn" -> nn.SpatialConvolutionMM, "cudnn" -> cudnn.SpatialConvolution.
  -- The remaining arguments are forwarded unchanged to the chosen constructor.
  -- Raises "unsupported backend:<backend>" for any other backend value.
  local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    local ctor
    if backend == "cunn" then
      ctor = nn.SpatialConvolutionMM
    elseif backend == "cudnn" then
      ctor = cudnn.SpatialConvolution
    else
      error("unsupported backend:" .. backend)
    end
    return ctor(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
  end
  145. srcnn.SpatialConvolution = SpatialConvolution
  -- Backend-aware full (transposed) convolution factory.
  --
  -- backend: "cunn" -> nn.SpatialFullConvolution,
  --          "cudnn" -> cudnn.SpatialFullConvolution.
  -- nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH are forwarded
  -- unchanged to the chosen constructor.
  -- Raises "unsupported backend:<backend>" for any other backend value.
  local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    if backend == "cunn" then
      return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    elseif backend == "cudnn" then
      -- adjW/adjH are forwarded here as well, so a requested output-size
      -- adjustment is not silently lost on the cudnn backend.
      return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    else
      error("unsupported backend:" .. backend)
    end
  end
  155. srcnn.SpatialFullConvolution = SpatialFullConvolution
  -- Backend-aware ReLU factory.
  --
  -- Returns nn.ReLU(true) for "cunn" and cudnn.ReLU(true) for "cudnn"; the `true`
  -- argument is passed to the constructor in both cases.
  -- Raises "unsupported backend:<backend>" for any other backend value.
  local function ReLU(backend)
    local pkg
    if backend == "cunn" then
      pkg = nn
    elseif backend == "cudnn" then
      pkg = cudnn
    else
      error("unsupported backend:" .. backend)
    end
    return pkg.ReLU(true)
  end
  165. srcnn.ReLU = ReLU
  -- Backend-aware max-pooling factory.
  --
  -- backend: "cunn" -> nn.SpatialMaxPooling, "cudnn" -> cudnn.SpatialMaxPooling.
  -- kW, kH, dW, dH, padW, padH are forwarded unchanged.
  -- Raises "unsupported backend:<backend>" for any other backend value.
  local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    local pkg
    if backend == "cunn" then
      pkg = nn
    elseif backend == "cudnn" then
      pkg = cudnn
    else
      error("unsupported backend:" .. backend)
    end
    return pkg.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
  end
  175. srcnn.SpatialMaxPooling = SpatialMaxPooling
  -- Backend-aware average-pooling factory.
  --
  -- backend: "cunn" -> nn.SpatialAveragePooling,
  --          "cudnn" -> cudnn.SpatialAveragePooling.
  -- kW, kH, dW, dH, padW, padH are forwarded unchanged.
  -- Raises "unsupported backend:<backend>" for any other backend value.
  local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
    local pkg
    if backend == "cunn" then
      pkg = nn
    elseif backend == "cudnn" then
      pkg = cudnn
    else
      error("unsupported backend:" .. backend)
    end
    return pkg.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
  end
  185. srcnn.SpatialAveragePooling = SpatialAveragePooling
  -- Backend-aware dilated convolution factory.
  --
  -- backend: "cunn"  -> nn.SpatialDilatedConvolution
  --          "cudnn" -> cudnn.SpatialDilatedConvolution when the cudnn package
  --                     provides it (cudnn v 6), nn.SpatialDilatedConvolution
  --                     otherwise.
  -- The remaining arguments are forwarded unchanged to the chosen constructor.
  -- Raises "unsupported backend:<backend>" for any other backend value.
  local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
    if backend ~= "cunn" and backend ~= "cudnn" then
      error("unsupported backend:" .. backend)
    end
    -- The cudnn class is only consulted for the cudnn backend; nn is the fallback.
    local ctor = (backend == "cudnn" and cudnn.SpatialDilatedConvolution)
      or nn.SpatialDilatedConvolution
    return ctor(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
  end
  200. srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
  201. -- VGG style net(7 layers)
  -- Builds the 7-convolution model.
  --
  -- backend: "cunn" or "cudnn" (passed to the SpatialConvolution factory)
  -- ch:      number of input/output image channels
  --
  -- Six unpadded 3x3 stride-1 convolutions, each followed by an in-place
  -- LeakyReLU(0.1), then a final 3x3 convolution back to `ch` planes,
  -- w2nn.InplaceClip01 and a flattening View. The w2nn_* metadata fields are
  -- attached to the returned nn.Sequential.
  function srcnn.vgg_7(backend, ch)
    -- {input planes, output planes} of every activated convolution, in order
    local hidden = {
      { ch, 32 }, { 32, 32 },
      { 32, 64 }, { 64, 64 },
      { 64, 128 }, { 128, 128 },
    }

    local model = nn.Sequential()
    for _, io in ipairs(hidden) do
      model:add(SpatialConvolution(backend, io[1], io[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "vgg_7"
    model.w2nn_offset = 7
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

    return model
  end
  227. -- VGG style net(12 layers)
  -- Builds the 12-convolution model.
  --
  -- backend: "cunn" or "cudnn" (passed to the SpatialConvolution factory)
  -- ch:      number of input/output image channels
  --
  -- Eleven unpadded 3x3 stride-1 convolutions, each followed by an in-place
  -- LeakyReLU(0.1), then a final 3x3 convolution back to `ch` planes,
  -- w2nn.InplaceClip01 and a flattening View. The w2nn_* metadata fields are
  -- attached to the returned nn.Sequential.
  function srcnn.vgg_12(backend, ch)
    -- {input planes, output planes} of every activated convolution, in order
    local hidden = {
      { ch, 32 }, { 32, 32 }, { 32, 64 },
      { 64, 64 }, { 64, 64 }, { 64, 64 },
      { 64, 64 }, { 64, 64 }, { 64, 64 },
      { 64, 128 }, { 128, 128 },
    }

    local model = nn.Sequential()
    for _, io in ipairs(hidden) do
      model:add(SpatialConvolution(backend, io[1], io[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "vgg_12"
    model.w2nn_offset = 12
    model.w2nn_scale_factor = 1
    model.w2nn_resize = false
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

    return model
  end
  264. -- Dilated Convolution (7 layers)
  -- Builds the 7-layer model whose three middle layers are dilated convolutions
  -- (dilation 2, 2 and 4).
  --
  -- backend: "cunn" or "cudnn" (used for the non-dilated convolutions)
  -- ch:      number of input/output image channels
  --
  -- NOTE(review): the dilated layers are created with nn.SpatialDilatedConvolution
  -- directly rather than through the backend-aware SpatialDilatedConvolution
  -- factory in this file, so they stay on nn even for backend == "cudnn".
  -- Confirm this is intentional.
  function srcnn.dilated_7(backend, ch)
    local model = nn.Sequential()

    local function add_conv(i, o)
      model:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
    end
    local function add_dilated(i, o, dilation)
      model:add(nn.SpatialDilatedConvolution(i, o, 3, 3, 1, 1, 0, 0, dilation, dilation))
    end
    local function add_act()
      model:add(nn.LeakyReLU(0.1, true))
    end

    add_conv(ch, 32)
    add_act()
    add_conv(32, 32)
    add_act()
    add_dilated(32, 64, 2)
    add_act()
    add_dilated(64, 64, 2)
    add_act()
    add_dilated(64, 128, 4)
    add_act()
    add_conv(128, 128)
    add_act()
    add_conv(128, ch)
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "dilated_7"
    model.w2nn_offset = 12
    model.w2nn_scale_factor = 1
    model.w2nn_resize = false
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

    return model
  end
  291. -- Upconvolution
  -- Builds the 2x upscaling model: six unpadded 3x3 convolutions with in-place
  -- LeakyReLU(0.1), followed by a bias-free 4x4 stride-2 full convolution
  -- (padding 3), w2nn.InplaceClip01 and a flattening View.
  --
  -- backend: "cunn" or "cudnn"
  -- ch:      number of input/output image channels
  function srcnn.upconv_7(backend, ch)
    -- {input planes, output planes} of every activated convolution, in order
    local hidden = {
      { ch, 16 }, { 16, 32 }, { 32, 64 },
      { 64, 128 }, { 128, 128 }, { 128, 256 },
    }

    local model = nn.Sequential()
    for _, io in ipairs(hidden) do
      model:add(SpatialConvolution(backend, io[1], io[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "upconv_7"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch

    return model
  end
  316. -- large version of upconv_7
  317. -- This model is able to beat upconv_7 (PSNR: +0.3 ~ +0.8), but it is 2x slower than upconv_7.
  -- Builds the wider variant of upconv_7: the same layer layout, with plane counts
  -- 32, 64, 128, 192, 256 and 512 ahead of the bias-free 4x4 stride-2 full
  -- convolution.
  --
  -- backend: "cunn" or "cudnn"
  -- ch:      number of input/output image channels
  function srcnn.upconv_7l(backend, ch)
    -- {input planes, output planes} of every activated convolution, in order
    local hidden = {
      { ch, 32 }, { 32, 64 }, { 64, 128 },
      { 128, 192 }, { 192, 256 }, { 256, 512 },
    }

    local model = nn.Sequential()
    for _, io in ipairs(hidden) do
      model:add(SpatialConvolution(backend, io[1], io[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "upconv_7l"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

    return model
  end
  344. -- layerwise linear blending with skip connections
  345. -- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l
  -- Builds the skip-connection model: six stages that each concatenate (along
  -- dimension 2) the output of a padded 3x3 convolution + LeakyReLU(0.1) with the
  -- stage's own input, followed by a bias-free full convolution that blends all
  -- accumulated planes back down to `ch`.
  --
  -- backend: "cunn" or "cudnn"
  -- ch:      number of input/output image channels
  --
  -- NOTE(review): the convolutions here use padding 1, so they keep the spatial
  -- size, whereas upconv_7 (same w2nn_offset of 14) uses unpadded convolutions.
  -- Confirm that 14 is the intended offset for this architecture.
  function srcnn.skiplb_7(backend, ch)
    -- One stage: depth-concat of [conv(i -> o) + activation, identity].
    local function skip(backend, i, o)
      local branch = nn.Sequential()
      branch:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1))
      branch:add(nn.LeakyReLU(0.1, true))

      local concat = nn.Concat(2)
      concat:add(branch)
      concat:add(nn.Identity()) -- skip
      return concat
    end

    local model = nn.Sequential()
    -- Every stage appends `width` planes to whatever it received, so the running
    -- plane count is the input channels plus all widths added so far.
    local planes = ch
    for _, width in ipairs({ 16, 32, 64, 128, 128, 256 }) do
      model:add(skip(backend, planes, width))
      planes = planes + width
    end
    -- input of last layer = [all layerwise output(contains input layer)].flatten
    model:add(SpatialFullConvolution(backend, planes, ch, 4, 4, 2, 2, 3, 3):noBias()) -- linear blend
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "skiplb_7"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

    return model
  end
  377. -- dilated convolution + deconvolution
  378. -- Note: This model is not better than upconv_7. Maybe because of under-fitting.
  -- Builds the 2x upscaling model whose three middle layers are dilated
  -- convolutions (dilation 2 each), ending with a bias-free 4x4 stride-2 full
  -- convolution.
  --
  -- backend: "cunn" or "cudnn" (used for the non-dilated layers)
  -- ch:      number of input/output image channels
  --
  -- NOTE(review): the dilated layers are created with nn.SpatialDilatedConvolution
  -- directly rather than through the backend-aware SpatialDilatedConvolution
  -- factory in this file. Confirm this is intentional.
  function srcnn.dilated_upconv_7(backend, ch)
    local model = nn.Sequential()

    local function add_conv(i, o)
      model:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    local function add_dilated(i, o)
      model:add(nn.SpatialDilatedConvolution(i, o, 3, 3, 1, 1, 0, 0, 2, 2))
      model:add(nn.LeakyReLU(0.1, true))
    end

    add_conv(ch, 16)
    add_conv(16, 32)
    add_dilated(32, 64)
    add_dilated(64, 128)
    add_dilated(128, 128)
    add_conv(128, 256)
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "dilated_upconv_7"
    model.w2nn_offset = 20
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

    return model
  end
  405. -- ref: https://arxiv.org/abs/1609.04802
  406. -- note: no batch-norm, no zero-padding
  -- Builds the residual 2x model: a 3x3 convolution to 64 planes, six residual
  -- blocks, a 4x4 stride-2 full convolution (padding 2) with ReLU, and a final 3x3
  -- convolution back to `ch` planes followed by w2nn.InplaceClip01.
  -- Unlike the other builders in this file, the output is not flattened (the
  -- View layer is commented out).
  --
  -- backend: "cunn" or "cudnn"
  -- ch:      number of input/output image channels
  function srcnn.srresnet_2x(backend, ch)
    -- One residual block: two unpadded 3x3 conv + ReLU layers summed with the
    -- block input cropped by 2 pixels on each side.
    local function resblock(backend)
      local body = nn.Sequential()
      body:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      body:add(ReLU(backend))
      body:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      body:add(ReLU(backend))

      local branches = nn.ConcatTable()
      branches:add(body)
      branches:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding

      local block = nn.Sequential()
      block:add(branches)
      block:add(nn.CAddTable())
      return block
    end

    local model = nn.Sequential()
    --model:add(skip(backend, ch, 64 - ch))
    model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    for _ = 1, 6 do
      model:add(resblock(backend))
    end
    model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2))
    model:add(ReLU(backend))
    model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.InplaceClip01())
    --model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "srresnet_2x"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

    return model
  end
  446. -- large version of srresnet_2x. It's current best model but slow.
  -- Builds the wider residual 2x model: a 3x3 convolution to 32 planes, six
  -- residual blocks growing from 32 to 256 planes, and a bias-free 4x4 stride-2
  -- full convolution back to `ch` planes.
  --
  -- backend: "cunn" or "cudnn"
  -- ch:      number of input/output image channels
  function srcnn.resnet_14l(backend, ch)
    -- One residual block mapping i planes to o planes. The shortcut crops the
    -- input by 2 pixels per side; when the plane count changes it is first
    -- projected with a 1x1 convolution.
    local function resblock(backend, i, o)
      local body = nn.Sequential()
      body:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
      body:add(nn.LeakyReLU(0.1, true))
      body:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
      body:add(nn.LeakyReLU(0.1, true))

      local branches = nn.ConcatTable()
      branches:add(body)
      if i == o then
        branches:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
      else
        local shortcut = nn.Sequential()
        shortcut:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
        shortcut:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
        branches:add(shortcut)
      end

      local block = nn.Sequential()
      block:add(branches)
      block:add(nn.CAddTable())
      return block
    end

    -- {input planes, output planes} of each residual block, in order
    local stages = {
      { 32, 64 }, { 64, 64 },
      { 64, 128 }, { 128, 128 },
      { 128, 256 }, { 256, 256 },
    }

    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    for _, io in ipairs(stages) do
      model:add(resblock(backend, io[1], io[2]))
    end
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "resnet_14l"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())

    return model
  end
  490. -- for segmentation
  -- Builds the encoder/decoder model: a strided 5x5 convolution, three
  -- conv/conv/max-pool stages, a 1x1 convolution with dropout, then a decoder of
  -- 2x2 stride-2 full convolutions interleaved with 3x3 convolutions, ending in
  -- a 4x4 stride-2 full convolution back to `ch` planes.
  -- The model records w2nn_input_size = 120.
  --
  -- backend: "cunn" or "cudnn"
  -- ch:      number of input/output image channels
  function srcnn.fcn_v1(backend, ch)
    -- input_size = 120
    local model = nn.Sequential()
    --i = 120
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())

    -- unpadded k x k convolution with stride s, followed by LeakyReLU(0.1)
    local function conv(i, o, k, s)
      model:add(SpatialConvolution(backend, i, o, k, k, s, s, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    -- 2x2 stride-2 full convolution, followed by LeakyReLU(0.1)
    local function up(i, o)
      model:add(SpatialFullConvolution(backend, i, o, 2, 2, 2, 2, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    -- 2x2 stride-2 max pooling
    local function pool()
      model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    end

    -- encoder
    conv(ch, 32, 5, 2)
    conv(32, 32, 3, 1)
    pool()
    conv(32, 64, 3, 1)
    conv(64, 64, 3, 1)
    pool()
    conv(64, 128, 3, 1)
    conv(128, 128, 3, 1)
    pool()
    conv(128, 256, 1, 1)
    model:add(nn.Dropout(0.5, false, true))

    -- decoder
    up(256, 128)
    up(128, 128)
    conv(128, 64, 3, 1)
    up(64, 64)
    conv(64, 32, 3, 1)
    model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "fcn_v1"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_input_size = 120
    --model.w2nn_gcn = true

    return model
  end
  -- Creates a model by architecture name.
  --
  -- model_name: key of a builder in the srcnn table (default "vgg_7")
  -- backend:    "cunn" or "cudnn" (default "cunn")
  -- color:      "rgb" (3 channels) or "y" (1 channel); default "rgb"
  --
  -- Returns the model produced by srcnn[model_name](backend, ch).
  -- Raises "unsupported color: ..." or "unsupported model_name: ..." for unknown
  -- values, and asserts that the model's offset is a multiple of its scale factor.
  function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
    color = color or "rgb"

    local channels_by_color = { rgb = 3, y = 1 }
    local ch = channels_by_color[color]
    if ch == nil then
      error("unsupported color: " .. color)
    end

    local builder = srcnn[model_name]
    if not builder then
      error("unsupported model_name: " .. model_name)
    end

    local model = builder(backend, ch)
    assert(model.w2nn_offset % model.w2nn_scale_factor == 0)
    return model
  end
  556. --[[
  557. local model = srcnn.fcn_v1("cunn", 3):cuda()
  558. print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
  559. print(model)
  560. --]]
  561. return srcnn