-- srcnn.lua

require 'w2nn'
-- ref: http://arxiv.org/abs/1502.01852
-- ref: http://arxiv.org/abs/1501.00092
local srcnn = {}

local function msra_filler(mod)
    local fin = mod.kW * mod.kH * mod.nInputPlane
    local fout = mod.kW * mod.kH * mod.nOutputPlane
    local stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
    mod.weight:normal(0, stdv)
    mod.bias:zero()
end
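-- Note: this appears to follow the initialization from the first reference above
-- (He et al., "Delving Deep into Rectifiers"), averaged over fan-in and fan-out:
-- stdv = sqrt(4 / ((1 + a^2) * (fin + fout))) with a = 0.1, the negative slope of
-- the LeakyReLU(0.1) activations used throughout this file.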
local function identity_filler(mod)
    assert(mod.nInputPlane <= mod.nOutputPlane)
    mod.weight:normal(0, 0.01)
    mod.bias:zero()
    local num_groups = mod.nInputPlane -- fixed
    local filler_value = num_groups / mod.nOutputPlane
    local in_group_size = math.floor(mod.nInputPlane / num_groups)
    local out_group_size = math.floor(mod.nOutputPlane / num_groups)
    local x = math.floor(mod.kW / 2)
    local y = math.floor(mod.kH / 2)
    for i = 0, num_groups - 1 do
        for j = i * out_group_size, (i + 1) * out_group_size - 1 do
            for k = i * in_group_size, (i + 1) * in_group_size - 1 do
                mod.weight[j+1][k+1][y+1][x+1] = filler_value
            end
        end
    end
end
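-- Note: identity_filler roughly initializes a (dilated) convolution as a pass-through:
-- on top of the small random noise, it writes filler_value at the spatial center tap
-- connecting each input channel to its matching output channels, so the layer starts
-- close to an identity mapping.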
function nn.SpatialConvolutionMM:reset(stdv)
    msra_filler(self)
end
function nn.SpatialFullConvolution:reset(stdv)
    msra_filler(self)
end
function nn.SpatialDilatedConvolution:reset(stdv)
    identity_filler(self)
end
if cudnn and cudnn.SpatialConvolution then
    function cudnn.SpatialConvolution:reset(stdv)
        msra_filler(self)
    end
    function cudnn.SpatialFullConvolution:reset(stdv)
        msra_filler(self)
    end
    if cudnn.SpatialDilatedConvolution then
        function cudnn.SpatialDilatedConvolution:reset(stdv)
            identity_filler(self)
        end
    end
end
function nn.SpatialConvolutionMM:clearState()
    if self.gradWeight then
        self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
    end
    if self.gradBias then
        self.gradBias:resize(self.nOutputPlane):zero()
    end
    return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
function srcnn.channels(model)
    if model.w2nn_channels ~= nil then
        return model.w2nn_channels
    else
        return model:get(model:size() - 1).weight:size(1)
    end
end
function srcnn.backend(model)
    local conv = model:findModules("cudnn.SpatialConvolution")
    local fullconv = model:findModules("cudnn.SpatialFullConvolution")
    if #conv > 0 or #fullconv > 0 then
        return "cudnn"
    else
        return "cunn"
    end
end
function srcnn.color(model)
    local ch = srcnn.channels(model)
    if ch == 3 then
        return "rgb"
    else
        return "y"
    end
end
function srcnn.name(model)
    if model.w2nn_arch_name ~= nil then
        return model.w2nn_arch_name
    else
        local conv = model:findModules("nn.SpatialConvolutionMM")
        if #conv == 0 then
            conv = model:findModules("cudnn.SpatialConvolution")
        end
        if #conv == 7 then
            return "vgg_7"
        elseif #conv == 12 then
            return "vgg_12"
        else
            error("unsupported model")
        end
    end
end
function srcnn.offset_size(model)
    if model.w2nn_offset ~= nil then
        return model.w2nn_offset
    else
        local name = srcnn.name(model)
        if name:match("vgg_") then
            local conv = model:findModules("nn.SpatialConvolutionMM")
            if #conv == 0 then
                conv = model:findModules("cudnn.SpatialConvolution")
            end
            local offset = 0
            for i = 1, #conv do
                offset = offset + (conv[i].kW - 1) / 2
            end
            return math.floor(offset)
        else
            error("unsupported model")
        end
    end
end
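-- Example: for "vgg_7", which stacks seven unpadded 3x3 convolutions, each layer
-- contributes (3 - 1) / 2 = 1 pixel per side, so offset_size() returns 7.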
function srcnn.scale_factor(model)
    if model.w2nn_scale_factor ~= nil then
        return model.w2nn_scale_factor
    else
        local name = srcnn.name(model)
        if name == "upconv_7" then
            return 2
        elseif name == "upconv_8_4x" then
            return 4
        else
            return 1
        end
    end
end
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
        return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
        return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialConvolution = SpatialConvolution

local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    if backend == "cunn" then
        return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    elseif backend == "cudnn" then
        return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialFullConvolution = SpatialFullConvolution

local function ReLU(backend)
    if backend == "cunn" then
        return nn.ReLU(true)
    elseif backend == "cudnn" then
        return cudnn.ReLU(true)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.ReLU = ReLU

local function Sigmoid(backend)
    if backend == "cunn" then
        return nn.Sigmoid(true)
    elseif backend == "cudnn" then
        return cudnn.Sigmoid(true)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.Sigmoid = Sigmoid
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
        return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
        return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialMaxPooling = SpatialMaxPooling

local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
        return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
        return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialAveragePooling = SpatialAveragePooling

local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
    if backend == "cunn" then
        return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
    elseif backend == "cudnn" then
        if cudnn.SpatialDilatedConvolution then
            -- cudnn v6
            return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
        else
            return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
        end
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
local function GlobalAveragePooling(n_output)
    local gap = nn.Sequential()
    gap:add(nn.Mean(-1, -1)):add(nn.Mean(-1, -1))
    gap:add(nn.View(-1, n_output, 1, 1))
    return gap
end
srcnn.GlobalAveragePooling = GlobalAveragePooling
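-- Note: the two nn.Mean(-1, -1) calls appear to average over the last dimension twice
-- (width, then height), reducing each feature map to a single value; nn.View then
-- restores an N x n_output x 1 x 1 shape for the 1x1 convolutions in SEBlock below.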
-- Squeeze and Excitation Block
local function SEBlock(backend, n_output, r)
    local con = nn.ConcatTable(2)
    local attention = nn.Sequential()
    local n_mid = math.floor(n_output / r)
    attention:add(GlobalAveragePooling(n_output))
    attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
    attention:add(nn.ReLU(true))
    attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
    attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
    con:add(nn.Identity())
    con:add(attention)
    return con
end
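-- SEBlock returns a ConcatTable producing {features, per-channel attention}; in this
-- file the pair is consumed by w2nn.ScaleTable (see ResBlockSE and unet_conv below).
-- A minimal usage sketch, assuming a 64-channel feature map (not executed here):
--[[
local block = nn.Sequential()
block:add(SEBlock("cunn", 64, 8))   -- {x, sigmoid of per-channel weights}
block:add(w2nn.ScaleTable())        -- x scaled by those per-channel weights
--]]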
-- I devised this architecture for the block-size problem with global average pooling,
-- but SEBlock seems to handle multi-scale input (or simply acts as a normalization)
-- without any problems, so this architecture is not used.
local function SpatialSEBlock(backend, ave_size, n_output, r)
    local con = nn.ConcatTable(2)
    local attention = nn.Sequential()
    local n_mid = math.floor(n_output / r)
    attention:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size))
    attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
    attention:add(nn.ReLU(true))
    attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
    attention:add(nn.Sigmoid(true))
    attention:add(nn.SpatialUpSamplingNearest(ave_size, ave_size))
    con:add(nn.Identity())
    con:add(attention)
    return con
end
local function ResBlock(backend, i, o)
    local seq = nn.Sequential()
    local con = nn.ConcatTable()
    local conv = nn.Sequential()
    conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    con:add(conv)
    if i == o then
        con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
    else
        local seq = nn.Sequential()
        seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
        seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
        con:add(seq)
    end
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
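-- Note: both 3x3 convolutions in the residual branch are unpadded, so the branch shrinks
-- by 2 pixels per side; nn.SpatialZeroPadding(-2, -2, -2, -2) on the shortcut crops the
-- identity (or the 1x1 projection) to the same size before nn.CAddTable.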
local function ResBlockSE(backend, i, o)
    local seq = nn.Sequential()
    local con = nn.ConcatTable()
    local conv = nn.Sequential()
    conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    conv:add(SEBlock(backend, o, 8))
    conv:add(w2nn.ScaleTable())
    con:add(conv)
    if i == o then
        con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
    else
        local seq = nn.Sequential()
        seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
        seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
        con:add(seq)
    end
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
local function ResGroup(backend, n, n_output)
    local seq = nn.Sequential()
    local res = nn.Sequential()
    local con = nn.ConcatTable(2)
    local depad = -2 * n
    for i = 1, n do
        res:add(ResBlock(backend, n_output, n_output))
    end
    con:add(res)
    con:add(nn.SpatialZeroPadding(depad, depad, depad, depad))
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
local function ResGroupSE(backend, n, n_output)
    local seq = nn.Sequential()
    local res = nn.Sequential()
    local con = nn.ConcatTable(2)
    local depad = -2 * n
    for i = 1, n do
        res:add(ResBlockSE(backend, n_output, n_output))
    end
    con:add(res)
    con:add(nn.SpatialZeroPadding(depad, depad, depad, depad))
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
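-- Note: each ResBlock/ResBlockSE shrinks the map by 2 pixels per side, so a group of
-- n blocks crops its shortcut by depad = -2 * n per side to keep both branches aligned.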
-- VGG style net (7 layers)
function srcnn.vgg_7(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "vgg_7"
    model.w2nn_offset = 7
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
end
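-- Shape check (a sketch, assuming a working CUDA/w2nn setup): the seven unpadded 3x3
-- convolutions remove 7 pixels per side, so a 32x32 input yields an 18x18 output,
-- flattened by the final nn.View.
--[[
local m = srcnn.vgg_7("cunn", 1):cuda()
local y = m:forward(torch.Tensor(1, 1, 32, 32):uniform():cuda())
print(y:size()) -- 1x324 (= 1*18*18)
--]]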
-- Upconvolution
function srcnn.upconv_7(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "upconv_7"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
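-- Note on w2nn_offset = 14: the six unpadded 3x3 convolutions crop 6 pixels per side
-- (W -> W - 12), and the 4x4/stride-2/pad-3 deconvolution then yields
-- (W - 12 - 1) * 2 - 6 + 4 = 2W - 28, i.e. 14 pixels lost per side at the 2x output.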
-- Large version of upconv_7.
-- This model can beat upconv_7 (PSNR: +0.3 ~ +0.8), but it is about 2x slower than upconv_7.
function srcnn.upconv_7l(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "upconv_7l"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
end
function srcnn.resnet_14l(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(ResBlock(backend, 32, 64))
    model:add(ResBlock(backend, 64, 64))
    model:add(ResBlock(backend, 64, 128))
    model:add(ResBlock(backend, 128, 128))
    model:add(ResBlock(backend, 128, 256))
    model:add(ResBlock(backend, 256, 256))
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "resnet_14l"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
end
-- ResNet with SEBlock for fast conversion
function srcnn.upresnet_s(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(ResGroupSE(backend, 3, 64))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())

    model.w2nn_arch_name = "upresnet_s"
    model.w2nn_offset = 18
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
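-- Note on w2nn_offset = 18: the two plain 3x3 convolutions and the three SE residual
-- blocks crop 1 + 6 + 1 = 8 pixels per side at input resolution (W -> W - 16), and the
-- 4x4/stride-2/pad-3 deconvolution then yields (W - 16 - 1) * 2 - 6 + 4 = 2W - 36,
-- i.e. 18 pixels lost per side at the 2x output resolution.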
-- Cascaded ResNet with SEBlock
function srcnn.upcresnet(backend, ch)
    local function resnet(backend, ch, deconv)
        local model = nn.Sequential()
        model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1, true))
        model:add(ResGroupSE(backend, 2, 64))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1, true))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()

    -- 2-stage cascade
    model:add(resnet(backend, ch, true))
    con:add(nn.Sequential():add(resnet(backend, ch, false)):add(nn.SpatialZeroPadding(-1, -1, -1, -1))) -- output size must be odd
    con:add(nn.SpatialZeroPadding(-8, -8, -8, -8))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01()))
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01()))
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1))

    model.w2nn_arch_name = "upcresnet"
    model.w2nn_offset = 22
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
-- for segmentation
function srcnn.fcn_v1(backend, ch)
    -- input_size = 120
    local model = nn.Sequential()
    --i = 120
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
    model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(nn.Dropout(0.5, false, true))
    model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))

    model.w2nn_arch_name = "fcn_v1"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_input_size = 120
    --model.w2nn_gcn = true
    return model
end
local function unet_branch(backend, insert, n_input, n_output, depad)
    local block = nn.Sequential()
    local con = nn.ConcatTable(2)
    local model = nn.Sequential()
    block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
    block:add(nn.LeakyReLU(0.1, true))
    block:add(insert)
    block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
    block:add(nn.LeakyReLU(0.1, true))
    con:add(block)
    con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
    model:add(con)
    model:add(nn.CAddTable())
    return model
end
local function unet_conv(backend, n_input, n_middle, n_output, se)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    if se then
        model:add(SEBlock(backend, n_output, 8))
        model:add(w2nn.ScaleTable())
    end
    return model
end
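-- unet_branch wraps `insert` between a 2x2/stride-2 convolution (downsampling) and a
-- 2x2/stride-2 full convolution (upsampling) and adds a shortcut around it; `depad`
-- crops the shortcut so both branches have the same spatial size before nn.CAddTable.
-- unet_conv is the basic conv-conv(-SE) unit used on both paths of the U-Nets below.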
-- Cascaded Residual Channel Attention U-Net
function srcnn.upcunet(backend, ch)
    -- Residual U-Net
    local function unet1(backend, ch, deconv)
        local block1 = unet_conv(backend, 64, 128, 64, true)
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block1, 64, 64, 4))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local function unet2(backend, ch, deconv)
        local block1 = unet_conv(backend, 128, 256, 128, true)
        local block2 = nn.Sequential()
        block2:add(unet_conv(backend, 64, 64, 128, true))
        block2:add(unet_branch(backend, block1, 128, 128, 4))
        block2:add(unet_conv(backend, 128, 64, 64, true))
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block2, 64, 64, 16))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()

    -- 2-stage cascade
    model:add(unet1(backend, ch, true))
    con:add(unet2(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for the single unet output

    model.w2nn_arch_name = "upcunet"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
    model.w2nn_resize = true
    model.w2nn_valid_input_size = {}
    for i = 76, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
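-- Note: w2nn_valid_input_size lists the crop sizes for which the nested 2x2
-- down/upsampling stages and the unpadded convolutions line up exactly; this
-- appears to be why only every 4th size between 76 and 512 is accepted.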
-- cunet for 1x
function srcnn.cunet(backend, ch)
    local function unet1(backend, ch, deconv)
        local block1 = unet_conv(backend, 64, 128, 64, true)
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block1, 64, 64, 4))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local function unet2(backend, ch, deconv)
        local block1 = unet_conv(backend, 128, 256, 128, true)
        local block2 = nn.Sequential()
        block2:add(unet_conv(backend, 64, 64, 128, true))
        block2:add(unet_branch(backend, block1, 128, 128, 4))
        block2:add(unet_conv(backend, 128, 64, 64, true))
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block2, 64, 64, 16))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()

    -- 2-stage cascade
    model:add(unet1(backend, ch))
    con:add(unet2(backend, ch))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for the single unet output

    model.w2nn_arch_name = "cunet"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_resize = false
    model.w2nn_valid_input_size = {}
    for i = 100, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
local function bench()
    local sys = require 'sys'
    cudnn.benchmark = true
    local model = nil
    local arch = {"upconv_7", "upcunet", "vgg_7", "cunet"}
    local backend = "cudnn"
    local ch = 3
    local batch_size = 1
    local output_size = 320
    for k = 1, #arch do
        model = srcnn[arch[k]](backend, ch):cuda()
        model:evaluate()
        local crop_size = nil
        if model.w2nn_resize then
            crop_size = (output_size + model.w2nn_offset * 2) / 2
        else
            crop_size = (output_size + model.w2nn_offset * 2)
        end
        local dummy = torch.Tensor(batch_size, ch, output_size, output_size):zero():cuda()
        print(arch[k], output_size, crop_size)
        -- warm up
        for i = 1, 4 do
            local x = torch.Tensor(batch_size, ch, crop_size, crop_size):uniform():cuda()
            model:forward(x)
        end
        local t = sys.clock()
        for i = 1, 100 do
            local x = torch.Tensor(batch_size, ch, crop_size, crop_size):uniform():cuda()
            local z = model:forward(x)
            dummy:add(z)
        end
        print(arch[k], sys.clock() - t)
        model:clearState()
    end
end
function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
    color = color or "rgb"
    local ch = 3
    if color == "rgb" then
        ch = 3
    elseif color == "y" then
        ch = 1
    else
        error("unsupported color: " .. color)
    end
    if srcnn[model_name] then
        local model = srcnn[model_name](backend, ch)
        assert(model.w2nn_offset % model.w2nn_scale_factor == 0)
        return model
    else
        error("unsupported model_name: " .. model_name)
    end
end
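-- Usage sketch (assumes the cunn/cudnn and w2nn packages are available):
--[[
local model = srcnn.create("upconv_7", "cunn", "rgb")
print(srcnn.name(model), srcnn.offset_size(model), srcnn.scale_factor(model))
--]]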
--[[
local model = srcnn.resnet_s("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 128, 128):zero():cuda()):size())
bench()
os.exit()
--]]
return srcnn