-- srcnn.lua

require 'w2nn'
-- ref: http://arxiv.org/abs/1502.01852
-- ref: http://arxiv.org/abs/1501.00092
local srcnn = {}
local function msra_filler(mod)
    local fin = mod.kW * mod.kH * mod.nInputPlane
    local fout = mod.kW * mod.kH * mod.nOutputPlane
    local stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
    mod.weight:normal(0, stdv)
    mod.bias:zero()
end
local function identity_filler(mod)
    assert(mod.nInputPlane <= mod.nOutputPlane)
    mod.weight:normal(0, 0.01)
    mod.bias:zero()
    local num_groups = mod.nInputPlane -- fixed
    local filler_value = num_groups / mod.nOutputPlane
    local in_group_size = math.floor(mod.nInputPlane / num_groups)
    local out_group_size = math.floor(mod.nOutputPlane / num_groups)
    local x = math.floor(mod.kW / 2)
    local y = math.floor(mod.kH / 2)
    for i = 0, num_groups - 1 do
        for j = i * out_group_size, (i + 1) * out_group_size - 1 do
            for k = i * in_group_size, (i + 1) * in_group_size - 1 do
                mod.weight[j+1][k+1][y+1][x+1] = filler_value
            end
        end
    end
end
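-- Notes on the fillers above: msra_filler is the He/MSRA initialization from the
-- 1502.01852 reference, stdv = sqrt(4 / ((1 + 0.1^2) * (fan_in + fan_out))),
-- where 0.1 presumably matches the LeakyReLU(0.1) slope used throughout this file.
-- identity_filler biases a (dilated) convolution toward an identity mapping by
-- writing num_groups / nOutputPlane into the center tap of each kernel group.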
function nn.SpatialConvolutionMM:reset(stdv)
    msra_filler(self)
end
function nn.SpatialFullConvolution:reset(stdv)
    msra_filler(self)
end
function nn.SpatialDilatedConvolution:reset(stdv)
    identity_filler(self)
end
if cudnn and cudnn.SpatialConvolution then
    function cudnn.SpatialConvolution:reset(stdv)
        msra_filler(self)
    end
    function cudnn.SpatialFullConvolution:reset(stdv)
        msra_filler(self)
    end
    if cudnn.SpatialDilatedConvolution then
        function cudnn.SpatialDilatedConvolution:reset(stdv)
            identity_filler(self)
        end
    end
end
function nn.SpatialConvolutionMM:clearState()
    if self.gradWeight then
        self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
    end
    if self.gradBias then
        self.gradBias:resize(self.nOutputPlane):zero()
    end
    return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
function srcnn.channels(model)
    if model.w2nn_channels ~= nil then
        return model.w2nn_channels
    else
        return model:get(model:size() - 1).weight:size(1)
    end
end
function srcnn.backend(model)
    local conv = model:findModules("cudnn.SpatialConvolution")
    local fullconv = model:findModules("cudnn.SpatialFullConvolution")
    if #conv > 0 or #fullconv > 0 then
        return "cudnn"
    else
        return "cunn"
    end
end
function srcnn.color(model)
    local ch = srcnn.channels(model)
    if ch == 3 then
        return "rgb"
    else
        return "y"
    end
end
function srcnn.name(model)
    if model.w2nn_arch_name ~= nil then
        return model.w2nn_arch_name
    else
        local conv = model:findModules("nn.SpatialConvolutionMM")
        if #conv == 0 then
            conv = model:findModules("cudnn.SpatialConvolution")
        end
        if #conv == 7 then
            return "vgg_7"
        elseif #conv == 12 then
            return "vgg_12"
        else
            error("unsupported model")
        end
    end
end
function srcnn.offset_size(model)
    if model.w2nn_offset ~= nil then
        return model.w2nn_offset
    else
        local name = srcnn.name(model)
        if name:match("vgg_") then
            local conv = model:findModules("nn.SpatialConvolutionMM")
            if #conv == 0 then
                conv = model:findModules("cudnn.SpatialConvolution")
            end
            local offset = 0
            for i = 1, #conv do
                offset = offset + (conv[i].kW - 1) / 2
            end
            return math.floor(offset)
        else
            error("unsupported model")
        end
    end
end
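-- Example: vgg_7 consists of seven unpadded 3x3 convolutions, so each layer trims
-- (3 - 1) / 2 = 1 pixel per side and offset_size() returns 7 -- the same value the
-- constructor below stores in model.w2nn_offset.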
function srcnn.scale_factor(model)
    if model.w2nn_scale_factor ~= nil then
        return model.w2nn_scale_factor
    else
        local name = srcnn.name(model)
        if name == "upconv_7" then
            return 2
        elseif name == "upconv_8_4x" then
            return 4
        else
            return 1
        end
    end
end
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
        return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
        return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialConvolution = SpatialConvolution
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    if backend == "cunn" then
        return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    elseif backend == "cudnn" then
        return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialFullConvolution = SpatialFullConvolution
local function ReLU(backend)
    if backend == "cunn" then
        return nn.ReLU(true)
    elseif backend == "cudnn" then
        return cudnn.ReLU(true)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.ReLU = ReLU
local function Sigmoid(backend)
    if backend == "cunn" then
        return nn.Sigmoid(true)
    elseif backend == "cudnn" then
        return cudnn.Sigmoid(true)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.Sigmoid = Sigmoid
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
        return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
        return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialMaxPooling = SpatialMaxPooling
local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
        return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
    elseif backend == "cudnn" then
        return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialAveragePooling = SpatialAveragePooling
local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
    if backend == "cunn" then
        return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
    elseif backend == "cudnn" then
        if cudnn.SpatialDilatedConvolution then
            -- cudnn v6
            return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
        else
            return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
        end
    else
        error("unsupported backend:" .. backend)
    end
end
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
local function GlobalAveragePooling(n_output)
    local gap = nn.Sequential()
    gap:add(nn.Mean(-1, -1)):add(nn.Mean(-1, -1))
    gap:add(nn.View(-1, n_output, 1, 1))
    return gap
end
srcnn.GlobalAveragePooling = GlobalAveragePooling
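-- GlobalAveragePooling averages over the last dimension twice (W, then H) and
-- reshapes the result to batch x n_output x 1 x 1, so the SE blocks below can
-- apply 1x1 convolutions to the pooled channel descriptor.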
-- Squeeze and Excitation Block
local function SEBlock(backend, n_output, r)
    local con = nn.ConcatTable(2)
    local attention = nn.Sequential()
    local n_mid = math.floor(n_output / r)
    attention:add(GlobalAveragePooling(n_output))
    attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
    attention:add(nn.ReLU(true))
    attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
    attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
    con:add(nn.Identity())
    con:add(attention)
    return con
end
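-- SEBlock returns a ConcatTable of {identity, channel attention}; callers pair it
-- with w2nn.ScaleTable (see ResBlockSE and unet_conv below) so the feature map is
-- rescaled per channel by the sigmoid weights, as in Squeeze-and-Excitation networks.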
-- I devised this arch to deal with the block-size problem of global average pooling,
-- but plain SEBlock seems to handle it without any problem (it may simply learn
-- multi-scale input or just a normalization), so this arch is not used.
local function SpatialSEBlock(backend, ave_size, n_output, r)
    local con = nn.ConcatTable(2)
    local attention = nn.Sequential()
    local n_mid = math.floor(n_output / r)
    attention:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size))
    attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
    attention:add(nn.ReLU(true))
    attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
    attention:add(nn.Sigmoid(true))
    attention:add(nn.SpatialUpSamplingNearest(ave_size, ave_size))
    con:add(nn.Identity())
    con:add(attention)
    return con
end
local function ResBlock(backend, i, o)
    local seq = nn.Sequential()
    local con = nn.ConcatTable()
    local conv = nn.Sequential()
    conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    con:add(conv)
    if i == o then
        con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
    else
        local seq = nn.Sequential()
        seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
        seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
        con:add(seq)
    end
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
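-- Each unpadded 3x3 convolution trims 1 pixel per side, so the two convolutions in
-- ResBlock shrink the map by 2 pixels per side; the shortcut branch therefore uses
-- nn.SpatialZeroPadding(-2, -2, -2, -2) to crop itself to the same size before
-- nn.CAddTable sums the two branches.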
local function ResBlockSE(backend, i, o)
    local seq = nn.Sequential()
    local con = nn.ConcatTable()
    local conv = nn.Sequential()
    conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
    conv:add(nn.LeakyReLU(0.1, true))
    conv:add(SEBlock(backend, o, 8))
    conv:add(w2nn.ScaleTable())
    con:add(conv)
    if i == o then
        con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
    else
        local seq = nn.Sequential()
        seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
        seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
        con:add(seq)
    end
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
local function ResGroup(backend, n, n_output)
    local seq = nn.Sequential()
    local res = nn.Sequential()
    local con = nn.ConcatTable(2)
    local depad = -2 * n
    for i = 1, n do
        res:add(ResBlock(backend, n_output, n_output))
    end
    con:add(res)
    con:add(nn.SpatialZeroPadding(depad, depad, depad, depad))
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
local function ResGroupSE(backend, n, n_output)
    local seq = nn.Sequential()
    local res = nn.Sequential()
    local con = nn.ConcatTable(2)
    local depad = -2 * n
    for i = 1, n do
        res:add(ResBlockSE(backend, n_output, n_output))
    end
    con:add(res)
    con:add(nn.SpatialZeroPadding(depad, depad, depad, depad))
    seq:add(con)
    seq:add(nn.CAddTable())
    return seq
end
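-- ResGroup / ResGroupSE stack n residual blocks and wrap them in one long skip
-- connection; since every block removes 2 pixels per side, the skip path is cropped
-- by depad = -2 * n so both ConcatTable branches align for nn.CAddTable.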
-- VGG style net (7 layers)
function srcnn.vgg_7(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "vgg_7"
    model.w2nn_offset = 7
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
end
-- Upconvolution
function srcnn.upconv_7(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "upconv_7"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
-- large version of upconv_7
-- This model is able to beat upconv_7 (PSNR: +0.3 ~ +0.8), but it is about 2x slower than upconv_7.
function srcnn.upconv_7l(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "upconv_7l"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
end
function srcnn.resnet_14l(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(ResBlock(backend, 32, 64))
    model:add(ResBlock(backend, 64, 64))
    model:add(ResBlock(backend, 64, 128))
    model:add(ResBlock(backend, 128, 128))
    model:add(ResBlock(backend, 128, 256))
    model:add(ResBlock(backend, 256, 256))
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "resnet_14l"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
end
-- ResNet with SEBlock for fast conversion
function srcnn.upresnet_s(backend, ch)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(ResGroupSE(backend, 3, 64))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model.w2nn_arch_name = "upresnet_s"
    model.w2nn_offset = 18
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
-- Cascaded ResNet with SEBlock
function srcnn.upcresnet(backend, ch)
    local function resnet(backend, ch, deconv)
        local model = nn.Sequential()
        model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1, true))
        model:add(ResGroupSE(backend, 2, 64))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1, true))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
    -- 2 cascade
    model:add(resnet(backend, ch, true))
    con:add(nn.Sequential():add(resnet(backend, ch, false)):add(nn.SpatialZeroPadding(-1, -1, -1, -1))) -- output is odd
    con:add(nn.SpatialZeroPadding(-8, -8, -8, -8))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single output
    model.w2nn_arch_name = "upcresnet"
    model.w2nn_offset = 22
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
end
-- for segmentation
function srcnn.fcn_v1(backend, ch)
    -- input_size = 120
    local model = nn.Sequential()
    --i = 120
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
    model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(nn.Dropout(0.5, false, true))
    model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "fcn_v1"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_input_size = 120
    --model.w2nn_gcn = true
    return model
end
local function unet_branch(backend, insert, backend, n_input, n_output, depad)
    local block = nn.Sequential()
    local con = nn.ConcatTable(2)
    local model = nn.Sequential()
    block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
    block:add(nn.LeakyReLU(0.1, true))
    block:add(insert)
    block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
    block:add(nn.LeakyReLU(0.1, true))
    con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
    con:add(block)
    model:add(con)
    model:add(nn.CAddTable())
    return model
end
local function unet_conv(backend, n_input, n_middle, n_output, se)
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    if se then
        model:add(SEBlock(backend, n_output, 8))
        model:add(w2nn.ScaleTable())
    end
    return model
end
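-- unet_branch wraps an inner module between a strided 2x2 downsampling convolution
-- and a 2x2 deconvolution, adding a skip connection that is cropped by -depad pixels
-- per side to match the branch output. unet_conv is the basic two-convolution block,
-- optionally followed by SEBlock + w2nn.ScaleTable for channel attention.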
-- Cascaded Residual Channel Attention U-Net
function srcnn.upcunet(backend, ch)
    -- Residual U-Net
    local function unet(backend, ch, deconv)
        local block1 = unet_conv(backend, 128, 256, 128, true)
        local block2 = nn.Sequential()
        block2:add(unet_conv(backend, 64, 64, 128, true))
        block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
        block2:add(unet_conv(backend, 128, 64, 64, true))
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block2, backend, 64, 64, 16))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
    -- 2 cascade
    model:add(unet(backend, ch, true))
    con:add(unet(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
    model.w2nn_arch_name = "upcunet"
    model.w2nn_offset = 60
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
    model.w2nn_resize = true
    model.w2nn_valid_input_size = {}
    for i = 76, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
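-- In the cascade above, the first U-Net produces the 2x upscaled image; the second
-- U-Net refines it, and nn.CAddTable adds the refinement to the cropped first-stage
-- output to form the main output, while w2nn.AuxiliaryLossTable(1) also supervises
-- the first-stage output alone. w2nn_valid_input_size enumerates the input sizes
-- (76..512 in steps of 4) accepted by this architecture.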
-- cunet for 1x
function srcnn.cunet(backend, ch)
    local function unet(backend, ch)
        local block1 = unet_conv(backend, 128, 256, 128, true)
        local block2 = nn.Sequential()
        block2:add(unet_conv(backend, 64, 64, 128, true))
        block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
        block2:add(unet_conv(backend, 128, 64, 64, true))
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block2, backend, 64, 64, 16))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
    -- 2 cascade
    model:add(unet(backend, ch))
    con:add(unet(backend, ch))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
    model.w2nn_arch_name = "cunet"
    model.w2nn_offset = 40
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_resize = false
    model.w2nn_valid_input_size = {}
    for i = 100, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
function srcnn.upcunet_s_p0(backend, ch)
    -- Residual U-Net
    local function unet1(backend, ch, deconv)
        local block1 = unet_conv(backend, 64, 128, 64, true)
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block1, backend, 64, 64, 4))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
    -- 2 cascade
    model:add(unet1(backend, ch, true))
    con:add(unet1(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-8, -8, -8, -8))
    --con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
    model.w2nn_arch_name = "upcunet_s_p0"
    model.w2nn_offset = 24
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
    model.w2nn_resize = true
    model.w2nn_valid_input_size = {}
    for i = 76, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
function srcnn.upcunet_s_p1(backend, ch)
    -- Residual U-Net
    local function unet1(backend, ch, deconv)
        local block1 = unet_conv(backend, 64, 128, 64, true)
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block1, backend, 64, 64, 4))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local function unet2(backend, ch, deconv)
        local block1 = unet_conv(backend, 128, 256, 128, true)
        local block2 = nn.Sequential()
        block2:add(unet_conv(backend, 64, 64, 128, true))
        block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
        block2:add(unet_conv(backend, 128, 64, 64, true))
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block2, backend, 64, 64, 16))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
    -- 2 cascade
    model:add(unet1(backend, ch, true))
    con:add(unet2(backend, ch, false))
    --con:add(nn.SpatialZeroPadding(-8, -8, -8, -8))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
    model.w2nn_arch_name = "upcunet_s_p1"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
    model.w2nn_resize = true
    model.w2nn_valid_input_size = {}
    for i = 76, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
function srcnn.upcunet_s_p2(backend, ch)
    -- Residual U-Net
    local function unet1(backend, ch, deconv)
        local block1 = unet_conv(backend, 64, 128, 64, true)
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block1, backend, 64, 64, 4))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local function unet2(backend, ch, deconv)
        local block1 = unet_conv(backend, 128, 256, 128, true)
        local block2 = nn.Sequential()
        block2:add(unet_conv(backend, 64, 64, 128, true))
        block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
        block2:add(unet_conv(backend, 128, 64, 64, true))
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 64, false))
        model:add(unet_branch(backend, block2, backend, 64, 64, 16))
        model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        if deconv then
            model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
        else
            model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        end
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
    -- 2 cascade
    model:add(unet2(backend, ch, true))
    con:add(unet1(backend, ch, false))
    con:add(nn.SpatialZeroPadding(-8, -8, -8, -8))
    --con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
    model.w2nn_arch_name = "upcunet_s_p2"
    model.w2nn_offset = 48
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
    model.w2nn_resize = true
    model.w2nn_valid_input_size = {}
    for i = 76, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
function srcnn.cunet_s(backend, ch)
    local function unet(backend, ch)
        local block1 = unet_conv(backend, 128, 256, 128, true)
        local block2 = nn.Sequential()
        block2:add(unet_conv(backend, 32, 64, 128, true))
        block2:add(unet_branch(backend, block1, backend, 128, 128, 4))
        block2:add(unet_conv(backend, 128, 64, 32, true))
        local model = nn.Sequential()
        model:add(unet_conv(backend, ch, 32, 32, false))
        model:add(unet_branch(backend, block2, backend, 32, 32, 16))
        model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
        model:add(nn.LeakyReLU(0.1))
        model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
        return model
    end
    local model = nn.Sequential()
    local con = nn.ConcatTable()
    local aux_con = nn.ConcatTable()
    -- 2 cascade
    model:add(unet(backend, ch))
    con:add(unet(backend, ch))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
    model:add(con)
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
    model.w2nn_arch_name = "cunet_s"
    model.w2nn_offset = 40
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_resize = false
    model.w2nn_valid_input_size = {}
    for i = 100, 512, 4 do
        table.insert(model.w2nn_valid_input_size, i)
    end
    return model
end
local function bench()
    local sys = require 'sys'
    cudnn.benchmark = true
    local model = nil
    local arch = {"upconv_7", "upresnet_s", "upcresnet", "resnet_14l", "upcunet", "upcunet_s_p0", "upcunet_s_p1", "upcunet_s_p2"}
    --local arch = {"upconv_7", "upcunet", "upcunet_v0", "upcunet_s", "vgg_7", "cunet", "cunet_s"}
    local backend = "cudnn"
    local ch = 3
    local batch_size = 1
    local output_size = 320
    for k = 1, #arch do
        model = srcnn[arch[k]](backend, ch):cuda()
        model:evaluate()
        local crop_size = (output_size + model.w2nn_offset * 2) / 2
        local dummy = torch.Tensor(batch_size, ch, output_size, output_size):zero():cuda()
        print(arch[k], output_size, crop_size)
        -- warm up
        for i = 1, 4 do
            local x = torch.Tensor(batch_size, ch, crop_size, crop_size):uniform():cuda()
            model:forward(x)
        end
        local t = sys.clock()
        for i = 1, 100 do
            local x = torch.Tensor(batch_size, ch, crop_size, crop_size):uniform():cuda()
            local z = model:forward(x)
            dummy:add(z)
        end
        print(arch[k], sys.clock() - t)
        model:clearState()
    end
end
function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
    color = color or "rgb"
    local ch = 3
    if color == "rgb" then
        ch = 3
    elseif color == "y" then
        ch = 1
    else
        error("unsupported color: " .. color)
    end
    if srcnn[model_name] then
        local model = srcnn[model_name](backend, ch)
        assert(model.w2nn_offset % model.w2nn_scale_factor == 0)
        return model
    else
        error("unsupported model_name: " .. model_name)
    end
end
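-- Example usage (sketch; the require path depends on how the repository is laid out,
-- and the input size is only one choice valid for upconv_7's offset of 14):
--   local srcnn = require 'lib/srcnn'
--   local model = srcnn.create("upconv_7", "cudnn", "rgb"):cuda()
--   print(model:forward(torch.Tensor(1, 3, 156, 156):uniform():cuda()):size())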
--[[
local model = srcnn.resnet_s("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 128, 128):zero():cuda()):size())
bench()
os.exit()
--]]
return srcnn