require 'w2nn'
-- ref: http://arxiv.org/abs/1502.01852
-- ref: http://arxiv.org/abs/1501.00092
local srcnn = {}
-- MSRA (He) style initialization; the 0.1 term matches the LeakyReLU(0.1) slope used below.
local function msra_filler(mod)
   local fin = mod.kW * mod.kH * mod.nInputPlane
   local fout = mod.kW * mod.kH * mod.nOutputPlane
   local stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
   mod.weight:normal(0, stdv)
   mod.bias:zero()
end
-- Near-identity initialization: small random weights plus a constant at the kernel center.
local function identity_filler(mod)
   assert(mod.nInputPlane <= mod.nOutputPlane)
   mod.weight:normal(0, 0.01)
   mod.bias:zero()
   local num_groups = mod.nInputPlane -- fixed
   local filler_value = num_groups / mod.nOutputPlane
   local in_group_size = math.floor(mod.nInputPlane / num_groups)
   local out_group_size = math.floor(mod.nOutputPlane / num_groups)
   local x = math.floor(mod.kW / 2)
   local y = math.floor(mod.kH / 2)
   for i = 0, num_groups - 1 do
      for j = i * out_group_size, (i + 1) * out_group_size - 1 do
         for k = i * in_group_size, (i + 1) * in_group_size - 1 do
            mod.weight[j+1][k+1][y+1][x+1] = filler_value
         end
      end
   end
end
function nn.SpatialConvolutionMM:reset(stdv)
   msra_filler(self)
end
function nn.SpatialFullConvolution:reset(stdv)
   msra_filler(self)
end
function nn.SpatialDilatedConvolution:reset(stdv)
   identity_filler(self)
end
if cudnn and cudnn.SpatialConvolution then
   function cudnn.SpatialConvolution:reset(stdv)
      msra_filler(self)
   end
   function cudnn.SpatialFullConvolution:reset(stdv)
      msra_filler(self)
   end
   if cudnn.SpatialDilatedConvolution then
      function cudnn.SpatialDilatedConvolution:reset(stdv)
         identity_filler(self)
      end
   end
end
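-- These overrides replace the default torch initializers, so any convolution
-- constructed after this point is initialized by the fillers above.
-- Rough sanity check (hypothetical values, not part of the original file):
--   local conv = nn.SpatialConvolutionMM(3, 32, 3, 3, 1, 1, 0, 0)
--   -- fin = 27, fout = 288, so conv.weight:std() should be close to
--   -- math.sqrt(4 / (1.01 * (27 + 288))) ~= 0.112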
function nn.SpatialConvolutionMM:clearState()
   if self.gradWeight then
      self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
   end
   if self.gradBias then
      self.gradBias:resize(self.nOutputPlane):zero()
   end
   return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
function srcnn.channels(model)
   if model.w2nn_channels ~= nil then
      return model.w2nn_channels
   else
      return model:get(model:size() - 1).weight:size(1)
   end
end
function srcnn.backend(model)
   local conv = model:findModules("cudnn.SpatialConvolution")
   local fullconv = model:findModules("cudnn.SpatialFullConvolution")
   if #conv > 0 or #fullconv > 0 then
      return "cudnn"
   else
      return "cunn"
   end
end
function srcnn.color(model)
   local ch = srcnn.channels(model)
   if ch == 3 then
      return "rgb"
   else
      return "y"
   end
end
function srcnn.name(model)
   if model.w2nn_arch_name ~= nil then
      return model.w2nn_arch_name
   else
      local conv = model:findModules("nn.SpatialConvolutionMM")
      if #conv == 0 then
         conv = model:findModules("cudnn.SpatialConvolution")
      end
      if #conv == 7 then
         return "vgg_7"
      elseif #conv == 12 then
         return "vgg_12"
      else
         error("unsupported model")
      end
   end
end
function srcnn.offset_size(model)
   if model.w2nn_offset ~= nil then
      return model.w2nn_offset
   else
      local name = srcnn.name(model)
      if name:match("vgg_") then
         local conv = model:findModules("nn.SpatialConvolutionMM")
         if #conv == 0 then
            conv = model:findModules("cudnn.SpatialConvolution")
         end
         local offset = 0
         for i = 1, #conv do
            offset = offset + (conv[i].kW - 1) / 2
         end
         return math.floor(offset)
      else
         error("unsupported model")
      end
   end
end
function srcnn.scale_factor(model)
   if model.w2nn_scale_factor ~= nil then
      return model.w2nn_scale_factor
   else
      local name = srcnn.name(model)
      if name == "upconv_7" then
         return 2
      elseif name == "upconv_8_4x" then
         return 4
      else
         return 1
      end
   end
end
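-- Usage sketch (values follow from the constructors defined below):
--   local m = srcnn.create("upconv_7", "cunn", "rgb")
--   print(srcnn.name(m), srcnn.channels(m), srcnn.offset_size(m), srcnn.scale_factor(m))
--   --> upconv_7   3   14   2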
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialConvolution = SpatialConvolution
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
   if backend == "cunn" then
      return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
   elseif backend == "cudnn" then
      return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialFullConvolution = SpatialFullConvolution
local function ReLU(backend)
   if backend == "cunn" then
      return nn.ReLU(true)
   elseif backend == "cudnn" then
      return cudnn.ReLU(true)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.ReLU = ReLU
local function Sigmoid(backend)
   if backend == "cunn" then
      return nn.Sigmoid(true)
   elseif backend == "cudnn" then
      return cudnn.Sigmoid(true)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.Sigmoid = Sigmoid
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialMaxPooling = SpatialMaxPooling
local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialAveragePooling = SpatialAveragePooling
local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   if backend == "cunn" then
      return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   elseif backend == "cudnn" then
      if cudnn.SpatialDilatedConvolution then
         -- cudnn v6
         return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
      else
         return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
      end
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
local function GlobalAveragePooling(n_output)
   local gap = nn.Sequential()
   gap:add(nn.Mean(-1, -1)):add(nn.Mean(-1, -1))
   gap:add(nn.View(-1, n_output, 1, 1))
   return gap
end
srcnn.GlobalAveragePooling = GlobalAveragePooling
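-- Note: the two nn.Mean(-1, -1) layers average over the last dimension twice
-- (W, then H), so GlobalAveragePooling(64) maps a (batch, 64, H, W) input to
-- a (batch, 64, 1, 1) output.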
-- VGG style net (7 layers)
function srcnn.vgg_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "vgg_7"
   model.w2nn_offset = 7
   model.w2nn_scale_factor = 1
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
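-- Shape check: each of the 7 unpadded 3x3 convolutions trims 1 pixel per side,
-- hence w2nn_offset = 7; e.g. a 92x92 input yields a 78x78 output.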
-- Upconvolution
function srcnn.upconv_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "upconv_7"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   return model
end
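-- Offset accounting (counted at output resolution): the 6 unpadded 3x3
-- convolutions trim 6 pixels per side at input scale (12 after 2x upsampling)
-- and the final 4x4/stride-2 full convolution with pad 3 trims 2 more,
-- giving w2nn_offset = 14.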
-- large version of upconv_7
-- This model is able to beat upconv_7 (PSNR: +0.3 ~ +0.8) but it is 2x slower than upconv_7.
function srcnn.upconv_7l(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "upconv_7l"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
function srcnn.resnet_14l(backend, ch)
   local function resblock(backend, i, o)
      local seq = nn.Sequential()
      local con = nn.ConcatTable()
      local conv = nn.Sequential()
      conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
      conv:add(nn.LeakyReLU(0.1, true))
      conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
      conv:add(nn.LeakyReLU(0.1, true))
      con:add(conv)
      if i == o then
         con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
      else
         local shortcut = nn.Sequential()
         shortcut:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
         shortcut:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
         con:add(shortcut)
      end
      seq:add(con)
      seq:add(nn.CAddTable())
      return seq
   end
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(resblock(backend, 32, 64))
   model:add(resblock(backend, 64, 64))
   model:add(resblock(backend, 64, 128))
   model:add(resblock(backend, 128, 128))
   model:add(resblock(backend, 128, 256))
   model:add(resblock(backend, 256, 256))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "resnet_14l"
   model.w2nn_offset = 28
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- for segmentation
function srcnn.fcn_v1(backend, ch)
   -- input_size = 120
   local model = nn.Sequential()
   --i = 120
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
   model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.Dropout(0.5, false, true))
   model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "fcn_v1"
   model.w2nn_offset = 36
   model.w2nn_scale_factor = 1
   model.w2nn_channels = ch
   model.w2nn_input_size = 120
   --model.w2nn_gcn = true
   return model
end
-- Squeeze and Excitation Block
local function SEBlock(backend, n_output, r)
   local con = nn.ConcatTable(2)
   local attention = nn.Sequential()
   local n_mid = math.floor(n_output / r)
   attention:add(GlobalAveragePooling(n_output))
   attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
   attention:add(nn.ReLU(true))
   attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
   attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
   con:add(nn.Identity())
   con:add(attention)
   return con
end
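-- SEBlock outputs the pair {features, per-channel attention weights in [0, 1]};
-- the caller (unet_conv below) follows it with w2nn.ScaleTable, which is expected
-- to rescale the feature maps channel-wise by those attention weights.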
-- I devised this arch for the block size and global average pooling problem,
-- but SEBlock seems to handle multi-scale input (or just act as a normalization)
-- without problems, so this arch is not used.
local function SpatialSEBlock(backend, ave_size, n_output, r)
   local con = nn.ConcatTable(2)
   local attention = nn.Sequential()
   local n_mid = math.floor(n_output / r)
   attention:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size))
   attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
   attention:add(nn.ReLU(true))
   attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
   attention:add(nn.Sigmoid(true))
   attention:add(nn.SpatialUpSamplingNearest(ave_size, ave_size))
   con:add(nn.Identity())
   con:add(attention)
   return con
end
local function unet_branch(backend, insert, n_input, n_output, depad)
   local block = nn.Sequential()
   local con = nn.ConcatTable(2)
   local model = nn.Sequential()
   block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
   block:add(nn.LeakyReLU(0.1, true))
   block:add(insert)
   block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
   block:add(nn.LeakyReLU(0.1, true))
   con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
   con:add(block)
   model:add(con)
   model:add(nn.CAddTable())
   return model
end
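-- unet_branch wraps `insert` in one U-Net level: downsample 2x, run `insert`,
-- upsample 2x, then add the result to the de-padded (skip) copy of the input.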
local function unet_conv(backend, n_input, n_middle, n_output, se)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   if se then
      model:add(SEBlock(backend, n_output, 8))
      model:add(w2nn.ScaleTable())
   end
   return model
end
-- Cascaded Residual Channel Attention U-Net
function srcnn.upcunet(backend, ch)
   -- Residual U-Net
   local function unet(backend, ch, deconv)
      local block1 = unet_conv(backend, 128, 256, 128, true)
      local block2 = nn.Sequential()
      block2:add(unet_conv(backend, 64, 64, 128, true))
      block2:add(unet_branch(backend, block1, 128, 128, 4))
      block2:add(unet_conv(backend, 128, 64, 64, true))
      local model = nn.Sequential()
      model:add(unet_conv(backend, ch, 32, 64, true))
      model:add(unet_branch(backend, block2, 64, 64, 16))
      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1))
      if deconv then
         model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
      else
         model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
      end
      return model
   end
   local model = nn.Sequential()
   local con = nn.ConcatTable()
   local aux_con = nn.ConcatTable()
   -- 2-stage cascade
   model:add(unet(backend, ch, true))
   con:add(unet(backend, ch, false))
   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
   model:add(con)
   model:add(aux_con)
   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
   model.w2nn_arch_name = "upcunet"
   model.w2nn_offset = 60
   model.w2nn_scale_factor = 2
   model.w2nn_channels = ch
   model.w2nn_resize = true
   model.w2nn_valid_input_size = {}
   for i = 76, 512, 4 do
      table.insert(model.w2nn_valid_input_size, i)
   end
   return model
end
-- cunet for 1x
function srcnn.cunet(backend, ch)
   local function unet(backend, ch)
      local block1 = unet_conv(backend, 128, 256, 128, true)
      local block2 = nn.Sequential()
      block2:add(unet_conv(backend, 64, 64, 128, true))
      block2:add(unet_branch(backend, block1, 128, 128, 4))
      block2:add(unet_conv(backend, 128, 64, 64, true))
      local model = nn.Sequential()
      model:add(unet_conv(backend, ch, 32, 64, true))
      model:add(unet_branch(backend, block2, 64, 64, 16))
      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1))
      model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
      return model
   end
   local model = nn.Sequential()
   local con = nn.ConcatTable()
   local aux_con = nn.ConcatTable()
   -- 2-stage cascade
   model:add(unet(backend, ch))
   con:add(unet(backend, ch))
   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
   model:add(con)
   model:add(aux_con)
   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
   model.w2nn_arch_name = "cunet"
   model.w2nn_offset = 40
   model.w2nn_scale_factor = 1
   model.w2nn_channels = ch
   model.w2nn_resize = false
   model.w2nn_valid_input_size = {}
   for i = 100, 512, 4 do
      table.insert(model.w2nn_valid_input_size, i)
   end
   return model
end
local function bench()
   local sys = require 'sys'
   cudnn.benchmark = true
   local model = nil
   local arch = {"upconv_7", "upcunet", "vgg_7", "cunet"}
   local backend = "cudnn"
   for k = 1, #arch do
      model = srcnn[arch[k]](backend, 3):cuda()
      model:evaluate()
      local dummy = nil
      -- warm up
      for i = 1, 20 do
         local x = torch.Tensor(4, 3, 172, 172):uniform():cuda()
         model:forward(x)
      end
      local t = sys.clock()
      for i = 1, 20 do
         local x = torch.Tensor(4, 3, 172, 172):uniform():cuda()
         local z = model:forward(x)
         if dummy == nil then
            dummy = z:clone()
         else
            dummy:add(z)
         end
      end
      print(arch[k], sys.clock() - t)
      model:clearState()
   end
end
function srcnn.create(model_name, backend, color)
   model_name = model_name or "vgg_7"
   backend = backend or "cunn"
   color = color or "rgb"
   local ch = 3
   if color == "rgb" then
      ch = 3
   elseif color == "y" then
      ch = 1
   else
      error("unsupported color: " .. color)
   end
   if srcnn[model_name] then
      local model = srcnn[model_name](backend, ch)
      assert(model.w2nn_offset % model.w2nn_scale_factor == 0)
      return model
   else
      error("unsupported model_name: " .. model_name)
   end
end
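-- Usage sketch: srcnn.create() with no arguments builds vgg_7 on the cunn
-- backend with 3 (rgb) channels; srcnn.create("vgg_7", "cudnn", "y") builds
-- the same arch for single-channel (luminance) input.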
--[[
local model = srcnn.cunet_v3("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()):size())
bench()
os.exit()
--]]
return srcnn