require 'w2nn'
-- ref: https://arxiv.org/abs/1502.01852
-- ref: https://arxiv.org/abs/1501.00092
-- ref: https://arxiv.org/abs/1709.01507
-- ref: https://arxiv.org/abs/1505.04597
local srcnn = {}
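-- He (MSRA) initialization (arXiv:1502.01852), with the variance adjusted for
-- LeakyReLU(0.1) activations and averaged over fan-in and fan-out.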
local function msra_filler(mod)
    local fin = mod.kW * mod.kH * mod.nInputPlane
    local fout = mod.kW * mod.kH * mod.nOutputPlane
    local stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
    mod.weight:normal(0, stdv)
    mod.bias:zero()
end
local function identity_filler(mod)
    assert(mod.nInputPlane <= mod.nOutputPlane)
    mod.weight:normal(0, 0.01)
    mod.bias:zero()
    local num_groups = mod.nInputPlane -- fixed
    local filler_value = num_groups / mod.nOutputPlane
    local in_group_size = math.floor(mod.nInputPlane / num_groups)
    local out_group_size = math.floor(mod.nOutputPlane / num_groups)
    local x = math.floor(mod.kW / 2)
    local y = math.floor(mod.kH / 2)
    for i = 0, num_groups - 1 do
        for j = i * out_group_size, (i + 1) * out_group_size - 1 do
            for k = i * in_group_size, (i + 1) * in_group_size - 1 do
                mod.weight[j+1][k+1][y+1][x+1] = filler_value
            end
        end
    end
end
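-- Override the default reset() of the convolution modules so that models are
-- initialized with the fillers above instead of torch's uniform default.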
function nn.SpatialConvolutionMM:reset(stdv)
    msra_filler(self)
end
function nn.SpatialFullConvolution:reset(stdv)
    msra_filler(self)
end
function nn.SpatialDilatedConvolution:reset(stdv)
    identity_filler(self)
end
if cudnn and cudnn.SpatialConvolution then
    function cudnn.SpatialConvolution:reset(stdv)
        msra_filler(self)
    end
    function cudnn.SpatialFullConvolution:reset(stdv)
        msra_filler(self)
    end
    if cudnn.SpatialDilatedConvolution then
        function cudnn.SpatialDilatedConvolution:reset(stdv)
            identity_filler(self)
        end
    end
end
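-- clearState() override: zero the gradient buffers and drop temporary tensors
-- so that serialized models stay small.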
function nn.SpatialConvolutionMM:clearState()
    if self.gradWeight then
        self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
    end
    if self.gradBias then
        self.gradBias:resize(self.nOutputPlane):zero()
    end
    return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
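-- Model introspection helpers. Newer models carry w2nn_* metadata fields;
-- the fallbacks below recover the same information from older models by
-- inspecting their convolution layers.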
function srcnn.channels(model)
    if model.w2nn_channels ~= nil then
        return model.w2nn_channels
    else
        return model:get(model:size() - 1).weight:size(1)
    end
end
function srcnn.backend(model)
    local conv = model:findModules("cudnn.SpatialConvolution")
    local fullconv = model:findModules("cudnn.SpatialFullConvolution")
    if #conv > 0 or #fullconv > 0 then
        return "cudnn"
    else
        return "cunn"
    end
end
function srcnn.color(model)
    local ch = srcnn.channels(model)
    if ch == 3 then
        return "rgb"
    else
        return "y"
    end
end
function srcnn.name(model)
    if model.w2nn_arch_name ~= nil then
        return model.w2nn_arch_name
    else
        local conv = model:findModules("nn.SpatialConvolutionMM")
        if #conv == 0 then
            conv = model:findModules("cudnn.SpatialConvolution")
        end
        if #conv == 7 then
            return "vgg_7"
        elseif #conv == 12 then
            return "vgg_12"
        else
            error("unsupported model")
        end
    end
end
function srcnn.offset_size(model)
    if model.w2nn_offset ~= nil then
        return model.w2nn_offset
    else
        local name = srcnn.name(model)
        if name:match("vgg_") then
            local conv = model:findModules("nn.SpatialConvolutionMM")
            if #conv == 0 then
                conv = model:findModules("cudnn.SpatialConvolution")
            end
            local offset = 0
            for i = 1, #conv do
                offset = offset + (conv[i].kW - 1) / 2
            end
            return math.floor(offset)
        else
            error("unsupported model")
        end
    end
end
function srcnn.scale_factor(model)
    if model.w2nn_scale_factor ~= nil then
        return model.w2nn_scale_factor
    else
        local name = srcnn.name(model)
        if name == "upconv_7" then
            return 2
        elseif name == "upconv_8_4x" then
            return 4
        else
            return 1
        end
    end
end
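-- Backend dispatch wrappers: build either nn.* (cunn) or cudnn.* modules from
-- the same call site, selected by the `backend` string.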
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
        return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
        return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialConvolution = SpatialConvolution
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    if backend == "cunn" then
        return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    elseif backend == "cudnn" then
        return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialFullConvolution = SpatialFullConvolution
local function ReLU(backend)
    if backend == "cunn" then
        return nn.ReLU(true)
    elseif backend == "cudnn" then
        return cudnn.ReLU(true)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.ReLU = ReLU
local function Sigmoid(backend)
    if backend == "cunn" then
        return nn.Sigmoid(true)
    elseif backend == "cudnn" then
        return cudnn.Sigmoid(true)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.Sigmoid = Sigmoid
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
        return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
        return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialMaxPooling = SpatialMaxPooling
local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
        return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
        return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialAveragePooling = SpatialAveragePooling
local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
    if backend == "cunn" then
        return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
    elseif backend == "cudnn" then
        if cudnn.SpatialDilatedConvolution then
            -- dilated convolution is available since cudnn v6
            return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
        else
            return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
        end
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
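-- Global average pooling over the spatial dimensions:
-- (batch, n_output, H, W) -> (batch, n_output, 1, 1).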
local function GlobalAveragePooling(n_output)
    local gap = nn.Sequential()
    gap:add(nn.Mean(-1, -1)):add(nn.Mean(-1, -1))
    gap:add(nn.View(-1, n_output, 1, 1))
    return gap
end
srcnn.GlobalAveragePooling = GlobalAveragePooling
-- Squeeze and Excitation Block
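-- Channel attention (arXiv:1709.01507): global average pooling -> 1x1 conv
-- (reduce by r) -> ReLU -> 1x1 conv (restore) -> Sigmoid. Returns a ConcatTable
-- of {identity, attention}; callers combine the two branches with w2nn.ScaleTable.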
local function SEBlock(backend, n_output, r)
    local con = nn.ConcatTable(2)
    local attention = nn.Sequential()
    local n_mid = math.floor(n_output / r)
    attention:add(GlobalAveragePooling(n_output))
    attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
    attention:add(nn.ReLU(true))
    attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
    attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
    con:add(nn.Identity())
    con:add(attention)
    return con
end
local function SpatialSEBlock(backend, ave_size, n_output, r)
    local con = nn.ConcatTable(2)
    local attention = nn.Sequential()
    local n_mid = math.floor(n_output / r)
    attention:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size))
    attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
    attention:add(nn.ReLU(true))
    attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
    attention:add(nn.Sigmoid(true))
    attention:add(nn.SpatialUpSamplingNearest(ave_size, ave_size))
    con:add(nn.Identity())
    con:add(attention)
    return con
end
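-- Residual block built from unpadded 3x3 convolutions. Each conv trims 1 pixel
-- per side, so the skip path is cropped with SpatialZeroPadding(-2, ...) to keep
-- both branches the same size before CAddTable.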
local function ResBlock(backend, i, o)
    local seq = nn.Sequential()
    local con = nn.ConcatTable()
    local conv = nn.Sequential()
    conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    con:add(conv)
    if i == o then
        con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
    else
        local seq = nn.Sequential()
        seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
        seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
        con:add(seq)
    end
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
local function ResBlockSE(backend, i, o)
    local seq = nn.Sequential()
    local con = nn.ConcatTable()
    local conv = nn.Sequential()
    conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    conv:add(SEBlock(backend, o, 8))
    conv:add(w2nn.ScaleTable())
    con:add(conv)
    if i == o then
        con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
    else
        local seq = nn.Sequential()
        seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
        seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
        con:add(seq)
    end
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
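-- A group of n residual blocks with a long skip connection; the skip is cropped
-- by 2*n pixels per side to match the shrinking main path.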
local function ResGroup(backend, n, n_output)
    local seq = nn.Sequential()
    local res = nn.Sequential()
    local con = nn.ConcatTable(2)
    local depad = -2 * n
    for i = 1, n do
        res:add(ResBlock(backend, n_output, n_output))
    end
    con:add(res)
    con:add(nn.SpatialZeroPadding(depad, depad, depad, depad))
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
local function ResGroupSE(backend, n, n_output)
    local seq = nn.Sequential()
    local res = nn.Sequential()
    local con = nn.ConcatTable(2)
    local depad = -2 * n
    for i = 1, n do
        res:add(ResBlockSE(backend, n_output, n_output))
    end
    con:add(res)
    con:add(nn.SpatialZeroPadding(depad, depad, depad, depad))
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
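-- Network definitions. Each constructor records metadata on the model:
-- w2nn_arch_name, w2nn_channels, w2nn_scale_factor (output upscaling ratio)
-- and w2nn_offset (border offset used for padding/cropping at inference time).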
-- VGG style net (7 layers)
function srcnn.vgg_7(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "vgg_7"
    model.w2nn_offset = 7
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    return model
end
-- Upconvolution
function srcnn.upconv_7(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "upconv_7"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
-- Large version of upconv_7.
-- This model beats upconv_7 (PSNR: +0.3 ~ +0.8) but is about 2x slower.
function srcnn.upconv_7l(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "upconv_7l"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
function srcnn.resnet_14l(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(ResBlock(backend, 32, 64))
    model:add(ResBlock(backend, 64, 64))
    model:add(ResBlock(backend, 64, 128))
    model:add(ResBlock(backend, 128, 128))
    model:add(ResBlock(backend, 128, 256))
    model:add(ResBlock(backend, 256, 256))
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "resnet_14l"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
-- ResNet with SEBlock for fast conversion
function srcnn.upresnet_s(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(ResGroupSE(backend, 3, 64))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model.w2nn_arch_name = "upresnet_s"
    model.w2nn_offset = 18
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
-- for segmentation
function srcnn.fcn_v1(backend, ch)
    -- input_size = 120
    local model = nn.Sequential()
    --i = 120
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
    model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(nn.Dropout(0.5, false, true))
    model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "fcn_v1"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_input_size = 120
    --model.w2nn_gcn = true
    return model
end
-- Cascaded Residual U-Net with SEBlock
-- unet utils adapted from https://gist.github.com/toshi-k/ca75e614f1ac12fa44f62014ac1d6465
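-- unet_conv: two unpadded 3x3 conv + LeakyReLU stages, optionally followed by an
-- SEBlock. unet_branch wraps `insert` between a strided-conv downsample and a
-- deconv upsample, with a cropped skip connection around the whole branch.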
local function unet_conv(backend, n_input, n_middle, n_output, se)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    if se then
        model:add(SEBlock(backend, n_output, 8))
        model:add(w2nn.ScaleTable())
    end
    return model
end
local function unet_branch(insert, backend, n_input, n_output, depad)
    local block = nn.Sequential()
    local con = nn.ConcatTable(2)
    local model = nn.Sequential()
    block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
    block:add(nn.LeakyReLU(0.1, true))
    block:add(insert)
    block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
    block:add(nn.LeakyReLU(0.1, true))
    con:add(block)
    con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
    model:add(con)
    model:add(nn.CAddTable())
    return model
end
local function cunet_unet1(backend, ch, deconv)
    local block1 = unet_conv(backend, 64, 128, 64, true)
    local model = nn.Sequential()
    model:add(unet_conv(backend, ch, 32, 64, false))
    model:add(unet_branch(block1, backend, 64, 64, 4))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
    if deconv then
        model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
    else
        model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
    end
    return model
end
local function cunet_unet2(backend, ch, deconv)
    local block1 = unet_conv(backend, 128, 256, 128, true)
    local block2 = nn.Sequential()
    block2:add(unet_conv(backend, 64, 64, 128, true))
    block2:add(unet_branch(block1, backend, 128, 128, 4))
    block2:add(unet_conv(backend, 128, 64, 64, true))
    local model = nn.Sequential()
    model:add(unet_conv(backend, ch, 32, 64, false))
    model:add(unet_branch(block2, backend, 64, 64, 16))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
    if deconv then
        model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
    else
        model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
    end
    return model
end
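-- upcunet/cunet cascade two U-Nets: the first U-Net's output is fed to the
-- second, the second's output is added to a cropped copy of the first's
-- (the cascaded output), and the cropped first-U-Net output is kept as an
-- auxiliary output for w2nn.AuxiliaryLossTable.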
-- 2x
function srcnn.upcunet(backend, ch)
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
    -- 2 cascade
    model:add(cunet_unet1(backend, ch, true))
    con:add(cunet_unet2(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
    model.w2nn_arch_name = "upcunet"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
    model.w2nn_resize = true
    model.w2nn_valid_input_size = {}
    for i = 76, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
-- 1x
function srcnn.cunet(backend, ch)
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
    -- 2 cascade
    model:add(cunet_unet1(backend, ch, false))
    con:add(cunet_unet2(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
    model.w2nn_arch_name = "cunet"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_resize = false
    model.w2nn_valid_input_size = {}
    for i = 100, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
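-- Rough forward-pass benchmark for a few architectures (cudnn backend only).
-- Input crop sizes are derived from each model's offset/resize metadata.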
local function bench()
    local sys = require 'sys'
    cudnn.benchmark = true
    local model = nil
    local arch = {"upconv_7", "upcunet", "vgg_7", "cunet"}
    local backend = "cudnn"
    local ch = 3
    local batch_size = 1
    local output_size = 256
    for k = 1, #arch do
        model = srcnn[arch[k]](backend, ch):cuda()
        model:evaluate()
        local dummy = nil
        local crop_size = nil
        if model.w2nn_resize then
            crop_size = (output_size + model.w2nn_offset * 2) / 2
        else
            crop_size = (output_size + model.w2nn_offset * 2)
        end
        local dummy = torch.Tensor(batch_size, ch, output_size, output_size):zero():cuda()
        print(arch[k], output_size, crop_size)
        -- warmup
        for i = 1, 4 do
            local x = torch.Tensor(batch_size, ch, crop_size, crop_size):uniform():cuda()
            model:forward(x)
        end
        local t = sys.clock()
        for i = 1, 10 do
            local x = torch.Tensor(batch_size, ch, crop_size, crop_size):uniform():cuda()
            local z = model:forward(x)
            dummy:add(z)
        end
        print(arch[k], sys.clock() - t)
        model:clearState()
    end
end
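-- Factory: srcnn.create(model_name, backend, color) builds one of the
-- architectures above, e.g. srcnn.create("upconv_7", "cudnn", "rgb").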
function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
    color = color or "rgb"
    local ch = 3
    if color == "rgb" then
        ch = 3
    elseif color == "y" then
        ch = 1
    else
        error("unsupported color: " .. color)
    end
    if srcnn[model_name] then
        local model = srcnn[model_name](backend, ch)
        assert(model.w2nn_offset % model.w2nn_scale_factor == 0)
        return model
    else
        error("unsupported model_name: " .. model_name)
    end
end
--[[
local model = srcnn.resnet_s("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 128, 128):zero():cuda()):size())
bench()
os.exit()
--]]
return srcnn