srcnn.lua 42 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213
  1. require 'w2nn'
  2. -- ref: http://arxiv.org/abs/1502.01852
  3. -- ref: http://arxiv.org/abs/1501.00092
  4. local srcnn = {}
  5. local function msra_filler(mod)
  6. local fin = mod.kW * mod.kH * mod.nInputPlane
  7. local fout = mod.kW * mod.kH * mod.nOutputPlane
  8. stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
  9. mod.weight:normal(0, stdv)
  10. mod.bias:zero()
  11. end
  12. local function identity_filler(mod)
  13. assert(mod.nInputPlane <= mod.nOutputPlane)
  14. mod.weight:normal(0, 0.01)
  15. mod.bias:zero()
  16. local num_groups = mod.nInputPlane -- fixed
  17. local filler_value = num_groups / mod.nOutputPlane
  18. local in_group_size = math.floor(mod.nInputPlane / num_groups)
  19. local out_group_size = math.floor(mod.nOutputPlane / num_groups)
  20. local x = math.floor(mod.kW / 2)
  21. local y = math.floor(mod.kH / 2)
  22. for i = 0, num_groups - 1 do
  23. for j = i * out_group_size, (i + 1) * out_group_size - 1 do
  24. for k = i * in_group_size, (i + 1) * in_group_size - 1 do
  25. mod.weight[j+1][k+1][y+1][x+1] = filler_value
  26. end
  27. end
  28. end
  29. end
  30. function nn.SpatialConvolutionMM:reset(stdv)
  31. msra_filler(self)
  32. end
  33. function nn.SpatialFullConvolution:reset(stdv)
  34. msra_filler(self)
  35. end
  36. function nn.SpatialDilatedConvolution:reset(stdv)
  37. identity_filler(self)
  38. end
  39. if cudnn and cudnn.SpatialConvolution then
  40. function cudnn.SpatialConvolution:reset(stdv)
  41. msra_filler(self)
  42. end
  43. function cudnn.SpatialFullConvolution:reset(stdv)
  44. msra_filler(self)
  45. end
  46. if cudnn.SpatialDilatedConvolution then
  47. function cudnn.SpatialDilatedConvolution:reset(stdv)
  48. identity_filler(self)
  49. end
  50. end
  51. end
  52. function nn.SpatialConvolutionMM:clearState()
  53. if self.gradWeight then
  54. self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
  55. end
  56. if self.gradBias then
  57. self.gradBias:resize(self.nOutputPlane):zero()
  58. end
  59. return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
  60. end
  61. function srcnn.channels(model)
  62. if model.w2nn_channels ~= nil then
  63. return model.w2nn_channels
  64. else
  65. return model:get(model:size() - 1).weight:size(1)
  66. end
  67. end
  68. function srcnn.backend(model)
  69. local conv = model:findModules("cudnn.SpatialConvolution")
  70. local fullconv = model:findModules("cudnn.SpatialFullConvolution")
  71. if #conv > 0 or #fullconv > 0 then
  72. return "cudnn"
  73. else
  74. return "cunn"
  75. end
  76. end
  77. function srcnn.color(model)
  78. local ch = srcnn.channels(model)
  79. if ch == 3 then
  80. return "rgb"
  81. else
  82. return "y"
  83. end
  84. end
  85. function srcnn.name(model)
  86. if model.w2nn_arch_name ~= nil then
  87. return model.w2nn_arch_name
  88. else
  89. local conv = model:findModules("nn.SpatialConvolutionMM")
  90. if #conv == 0 then
  91. conv = model:findModules("cudnn.SpatialConvolution")
  92. end
  93. if #conv == 7 then
  94. return "vgg_7"
  95. elseif #conv == 12 then
  96. return "vgg_12"
  97. else
  98. error("unsupported model")
  99. end
  100. end
  101. end
  102. function srcnn.offset_size(model)
  103. if model.w2nn_offset ~= nil then
  104. return model.w2nn_offset
  105. else
  106. local name = srcnn.name(model)
  107. if name:match("vgg_") then
  108. local conv = model:findModules("nn.SpatialConvolutionMM")
  109. if #conv == 0 then
  110. conv = model:findModules("cudnn.SpatialConvolution")
  111. end
  112. local offset = 0
  113. for i = 1, #conv do
  114. offset = offset + (conv[i].kW - 1) / 2
  115. end
  116. return math.floor(offset)
  117. else
  118. error("unsupported model")
  119. end
  120. end
  121. end
  122. function srcnn.scale_factor(model)
  123. if model.w2nn_scale_factor ~= nil then
  124. return model.w2nn_scale_factor
  125. else
  126. local name = srcnn.name(model)
  127. if name == "upconv_7" then
  128. return 2
  129. elseif name == "upconv_8_4x" then
  130. return 4
  131. else
  132. return 1
  133. end
  134. end
  135. end
  136. local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
  137. if backend == "cunn" then
  138. return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
  139. elseif backend == "cudnn" then
  140. return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
  141. else
  142. error("unsupported backend:" .. backend)
  143. end
  144. end
  145. srcnn.SpatialConvolution = SpatialConvolution
  146. local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
  147. if backend == "cunn" then
  148. return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
  149. elseif backend == "cudnn" then
  150. return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
  151. else
  152. error("unsupported backend:" .. backend)
  153. end
  154. end
  155. srcnn.SpatialFullConvolution = SpatialFullConvolution
  156. local function ReLU(backend)
  157. if backend == "cunn" then
  158. return nn.ReLU(true)
  159. elseif backend == "cudnn" then
  160. return cudnn.ReLU(true)
  161. else
  162. error("unsupported backend:" .. backend)
  163. end
  164. end
  165. srcnn.ReLU = ReLU
  166. local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
  167. if backend == "cunn" then
  168. return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
  169. elseif backend == "cudnn" then
  170. return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
  171. else
  172. error("unsupported backend:" .. backend)
  173. end
  174. end
  175. srcnn.SpatialMaxPooling = SpatialMaxPooling
  176. local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
  177. if backend == "cunn" then
  178. return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
  179. elseif backend == "cudnn" then
  180. return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
  181. else
  182. error("unsupported backend:" .. backend)
  183. end
  184. end
  185. srcnn.SpatialAveragePooling = SpatialAveragePooling
  186. local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
  187. if backend == "cunn" then
  188. return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
  189. elseif backend == "cudnn" then
  190. if cudnn.SpatialDilatedConvolution then
  191. -- cudnn v 6
  192. return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
  193. else
  194. return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
  195. end
  196. else
  197. error("unsupported backend:" .. backend)
  198. end
  199. end
  200. srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
  201. -- VGG style net(7 layers)
  202. function srcnn.vgg_7(backend, ch)
  203. local model = nn.Sequential()
  204. model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
  205. model:add(nn.LeakyReLU(0.1, true))
  206. model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
  207. model:add(nn.LeakyReLU(0.1, true))
  208. model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
  209. model:add(nn.LeakyReLU(0.1, true))
  210. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  211. model:add(nn.LeakyReLU(0.1, true))
  212. model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  213. model:add(nn.LeakyReLU(0.1, true))
  214. model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
  215. model:add(nn.LeakyReLU(0.1, true))
  216. model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
  217. model:add(w2nn.InplaceClip01())
  218. model:add(nn.View(-1):setNumInputDims(3))
  219. model.w2nn_arch_name = "vgg_7"
  220. model.w2nn_offset = 7
  221. model.w2nn_scale_factor = 1
  222. model.w2nn_channels = ch
  223. --model:cuda()
  224. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  225. return model
  226. end
  227. -- VGG style net(12 layers)
  228. function srcnn.vgg_12(backend, ch)
  229. local model = nn.Sequential()
  230. model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
  231. model:add(nn.LeakyReLU(0.1, true))
  232. model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
  233. model:add(nn.LeakyReLU(0.1, true))
  234. model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
  235. model:add(nn.LeakyReLU(0.1, true))
  236. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  237. model:add(nn.LeakyReLU(0.1, true))
  238. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  239. model:add(nn.LeakyReLU(0.1, true))
  240. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  241. model:add(nn.LeakyReLU(0.1, true))
  242. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  243. model:add(nn.LeakyReLU(0.1, true))
  244. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  245. model:add(nn.LeakyReLU(0.1, true))
  246. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  247. model:add(nn.LeakyReLU(0.1, true))
  248. model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  249. model:add(nn.LeakyReLU(0.1, true))
  250. model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
  251. model:add(nn.LeakyReLU(0.1, true))
  252. model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
  253. model:add(w2nn.InplaceClip01())
  254. model:add(nn.View(-1):setNumInputDims(3))
  255. model.w2nn_arch_name = "vgg_12"
  256. model.w2nn_offset = 12
  257. model.w2nn_scale_factor = 1
  258. model.w2nn_resize = false
  259. model.w2nn_channels = ch
  260. --model:cuda()
  261. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  262. return model
  263. end
  264. -- Dilated Convolution (7 layers)
  265. function srcnn.dilated_7(backend, ch)
  266. local model = nn.Sequential()
  267. model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
  268. model:add(nn.LeakyReLU(0.1, true))
  269. model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
  270. model:add(nn.LeakyReLU(0.1, true))
  271. model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
  272. model:add(nn.LeakyReLU(0.1, true))
  273. model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
  274. model:add(nn.LeakyReLU(0.1, true))
  275. model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
  276. model:add(nn.LeakyReLU(0.1, true))
  277. model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
  278. model:add(nn.LeakyReLU(0.1, true))
  279. model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
  280. model:add(w2nn.InplaceClip01())
  281. model:add(nn.View(-1):setNumInputDims(3))
  282. model.w2nn_arch_name = "dilated_7"
  283. model.w2nn_offset = 12
  284. model.w2nn_scale_factor = 1
  285. model.w2nn_resize = false
  286. model.w2nn_channels = ch
  287. --model:cuda()
  288. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  289. return model
  290. end
  291. -- Upconvolution
  292. function srcnn.upconv_7(backend, ch)
  293. local model = nn.Sequential()
  294. model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
  295. model:add(nn.LeakyReLU(0.1, true))
  296. model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
  297. model:add(nn.LeakyReLU(0.1, true))
  298. model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
  299. model:add(nn.LeakyReLU(0.1, true))
  300. model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  301. model:add(nn.LeakyReLU(0.1, true))
  302. model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
  303. model:add(nn.LeakyReLU(0.1, true))
  304. model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
  305. model:add(nn.LeakyReLU(0.1, true))
  306. model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
  307. model:add(w2nn.InplaceClip01())
  308. model:add(nn.View(-1):setNumInputDims(3))
  309. model.w2nn_arch_name = "upconv_7"
  310. model.w2nn_offset = 14
  311. model.w2nn_scale_factor = 2
  312. model.w2nn_resize = true
  313. model.w2nn_channels = ch
  314. return model
  315. end
  316. -- large version of upconv_7
  317. -- This model able to beat upconv_7 (PSNR: +0.3 ~ +0.8) but this model is 2x slower than upconv_7.
  318. function srcnn.upconv_7l(backend, ch)
  319. local model = nn.Sequential()
  320. model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
  321. model:add(nn.LeakyReLU(0.1, true))
  322. model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
  323. model:add(nn.LeakyReLU(0.1, true))
  324. model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  325. model:add(nn.LeakyReLU(0.1, true))
  326. model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0))
  327. model:add(nn.LeakyReLU(0.1, true))
  328. model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0))
  329. model:add(nn.LeakyReLU(0.1, true))
  330. model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0))
  331. model:add(nn.LeakyReLU(0.1, true))
  332. model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
  333. model:add(w2nn.InplaceClip01())
  334. model:add(nn.View(-1):setNumInputDims(3))
  335. model.w2nn_arch_name = "upconv_7l"
  336. model.w2nn_offset = 14
  337. model.w2nn_scale_factor = 2
  338. model.w2nn_resize = true
  339. model.w2nn_channels = ch
  340. --model:cuda()
  341. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  342. return model
  343. end
  344. -- layerwise linear blending with skip connections
  345. -- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l
  346. function srcnn.skiplb_7(backend, ch)
  347. local function skip(backend, i, o)
  348. local con = nn.Concat(2)
  349. local conv = nn.Sequential()
  350. conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1))
  351. conv:add(nn.LeakyReLU(0.1, true))
  352. -- depth concat
  353. con:add(conv)
  354. con:add(nn.Identity()) -- skip
  355. return con
  356. end
  357. local model = nn.Sequential()
  358. model:add(skip(backend, ch, 16))
  359. model:add(skip(backend, 16+ch, 32))
  360. model:add(skip(backend, 32+16+ch, 64))
  361. model:add(skip(backend, 64+32+16+ch, 128))
  362. model:add(skip(backend, 128+64+32+16+ch, 128))
  363. model:add(skip(backend, 128+128+64+32+16+ch, 256))
  364. -- input of last layer = [all layerwise output(contains input layer)].flatten
  365. model:add(SpatialFullConvolution(backend, 256+128+128+64+32+16+ch, ch, 4, 4, 2, 2, 3, 3):noBias()) -- linear blend
  366. model:add(w2nn.InplaceClip01())
  367. model:add(nn.View(-1):setNumInputDims(3))
  368. model.w2nn_arch_name = "skiplb_7"
  369. model.w2nn_offset = 14
  370. model.w2nn_scale_factor = 2
  371. model.w2nn_resize = true
  372. model.w2nn_channels = ch
  373. --model:cuda()
  374. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  375. return model
  376. end
  377. -- dilated convolution + deconvolution
  378. -- Note: This model is not better than upconv_7. Maybe becuase of under-fitting.
  379. function srcnn.dilated_upconv_7(backend, ch)
  380. local model = nn.Sequential()
  381. model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
  382. model:add(nn.LeakyReLU(0.1, true))
  383. model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
  384. model:add(nn.LeakyReLU(0.1, true))
  385. model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
  386. model:add(nn.LeakyReLU(0.1, true))
  387. model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 2, 2))
  388. model:add(nn.LeakyReLU(0.1, true))
  389. model:add(nn.SpatialDilatedConvolution(128, 128, 3, 3, 1, 1, 0, 0, 2, 2))
  390. model:add(nn.LeakyReLU(0.1, true))
  391. model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
  392. model:add(nn.LeakyReLU(0.1, true))
  393. model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
  394. model:add(w2nn.InplaceClip01())
  395. model:add(nn.View(-1):setNumInputDims(3))
  396. model.w2nn_arch_name = "dilated_upconv_7"
  397. model.w2nn_offset = 20
  398. model.w2nn_scale_factor = 2
  399. model.w2nn_resize = true
  400. model.w2nn_channels = ch
  401. --model:cuda()
  402. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  403. return model
  404. end
  405. -- ref: https://arxiv.org/abs/1609.04802
  406. -- note: no batch-norm, no zero-paading
  407. function srcnn.srresnet_2x(backend, ch)
  408. local function resblock(backend)
  409. local seq = nn.Sequential()
  410. local con = nn.ConcatTable()
  411. local conv = nn.Sequential()
  412. conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  413. conv:add(ReLU(backend))
  414. conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  415. conv:add(ReLU(backend))
  416. con:add(conv)
  417. con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
  418. seq:add(con)
  419. seq:add(nn.CAddTable())
  420. return seq
  421. end
  422. local model = nn.Sequential()
  423. --model:add(skip(backend, ch, 64 - ch))
  424. model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
  425. model:add(nn.LeakyReLU(0.1, true))
  426. model:add(resblock(backend))
  427. model:add(resblock(backend))
  428. model:add(resblock(backend))
  429. model:add(resblock(backend))
  430. model:add(resblock(backend))
  431. model:add(resblock(backend))
  432. model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2))
  433. model:add(ReLU(backend))
  434. model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
  435. model:add(w2nn.InplaceClip01())
  436. --model:add(nn.View(-1):setNumInputDims(3))
  437. model.w2nn_arch_name = "srresnet_2x"
  438. model.w2nn_offset = 28
  439. model.w2nn_scale_factor = 2
  440. model.w2nn_resize = true
  441. model.w2nn_channels = ch
  442. --model:cuda()
  443. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  444. return model
  445. end
  446. -- large version of srresnet_2x. It's current best model but slow.
  447. function srcnn.resnet_14l(backend, ch)
  448. local function resblock(backend, i, o)
  449. local seq = nn.Sequential()
  450. local con = nn.ConcatTable()
  451. local conv = nn.Sequential()
  452. conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
  453. conv:add(nn.LeakyReLU(0.1, true))
  454. conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
  455. conv:add(nn.LeakyReLU(0.1, true))
  456. con:add(conv)
  457. if i == o then
  458. con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
  459. else
  460. local seq = nn.Sequential()
  461. seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
  462. seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
  463. con:add(seq)
  464. end
  465. seq:add(con)
  466. seq:add(nn.CAddTable())
  467. return seq
  468. end
  469. local model = nn.Sequential()
  470. model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
  471. model:add(nn.LeakyReLU(0.1, true))
  472. model:add(resblock(backend, 32, 64))
  473. model:add(resblock(backend, 64, 64))
  474. model:add(resblock(backend, 64, 128))
  475. model:add(resblock(backend, 128, 128))
  476. model:add(resblock(backend, 128, 256))
  477. model:add(resblock(backend, 256, 256))
  478. model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
  479. model:add(w2nn.InplaceClip01())
  480. model:add(nn.View(-1):setNumInputDims(3))
  481. model.w2nn_arch_name = "resnet_14l"
  482. model.w2nn_offset = 28
  483. model.w2nn_scale_factor = 2
  484. model.w2nn_resize = true
  485. model.w2nn_channels = ch
  486. --model:cuda()
  487. --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
  488. return model
  489. end
  490. -- for segmentation
  491. function srcnn.fcn_v1(backend, ch)
  492. -- input_size = 120
  493. local model = nn.Sequential()
  494. --i = 120
  495. --model:cuda()
  496. --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
  497. model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0))
  498. model:add(nn.LeakyReLU(0.1, true))
  499. model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
  500. model:add(nn.LeakyReLU(0.1, true))
  501. model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
  502. model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
  503. model:add(nn.LeakyReLU(0.1, true))
  504. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  505. model:add(nn.LeakyReLU(0.1, true))
  506. model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
  507. model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  508. model:add(nn.LeakyReLU(0.1, true))
  509. model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
  510. model:add(nn.LeakyReLU(0.1, true))
  511. model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
  512. model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
  513. model:add(nn.LeakyReLU(0.1, true))
  514. model:add(nn.Dropout(0.5, false, true))
  515. model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
  516. model:add(nn.LeakyReLU(0.1, true))
  517. model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0))
  518. model:add(nn.LeakyReLU(0.1, true))
  519. model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0))
  520. model:add(nn.LeakyReLU(0.1, true))
  521. model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0))
  522. model:add(nn.LeakyReLU(0.1, true))
  523. model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0))
  524. model:add(nn.LeakyReLU(0.1, true))
  525. model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
  526. model:add(w2nn.InplaceClip01())
  527. model:add(nn.View(-1):setNumInputDims(3))
  528. model.w2nn_arch_name = "fcn_v1"
  529. model.w2nn_offset = 36
  530. model.w2nn_scale_factor = 1
  531. model.w2nn_channels = ch
  532. model.w2nn_input_size = 120
  533. --model.w2nn_gcn = true
  534. return model
  535. end
  536. function srcnn.cupconv_14(backend, ch)
  537. local function skip(backend, n_input, n_output, pad)
  538. local con = nn.ConcatTable()
  539. local conv = nn.Sequential()
  540. local depad = nn.Sequential()
  541. conv:add(nn.SelectTable(1))
  542. conv:add(SpatialConvolution(backend, n_input, n_output, 3, 3, 1, 1, 0, 0))
  543. conv:add(nn.LeakyReLU(0.1, true))
  544. con:add(conv)
  545. con:add(nn.Identity())
  546. return con
  547. end
  548. local function concat(backend, n, ch, n_middle)
  549. local con = nn.ConcatTable()
  550. for i = 1, n do
  551. local pad = i - 1
  552. if i == 1 then
  553. con:add(nn.Sequential():add(nn.SelectTable(i)))
  554. else
  555. local seq = nn.Sequential()
  556. seq:add(nn.SelectTable(i))
  557. if pad > 0 then
  558. seq:add(nn.SpatialZeroPadding(-pad, -pad, -pad, -pad))
  559. end
  560. if i == n then
  561. --seq:add(SpatialConvolution(backend, ch, n_middle, 1, 1, 1, 1, 0, 0))
  562. else
  563. seq:add(w2nn.GradWeight(0.025))
  564. seq:add(SpatialConvolution(backend, n_middle, n_middle, 1, 1, 1, 1, 0, 0))
  565. end
  566. seq:add(nn.LeakyReLU(0.1, true))
  567. con:add(seq)
  568. end
  569. end
  570. return nn.Sequential():add(con):add(nn.JoinTable(2))
  571. end
  572. local model = nn.Sequential()
  573. local m = 64
  574. local n = 14
  575. model:add(nn.ConcatTable():add(nn.Identity()))
  576. for i = 1, n - 1 do
  577. if i == 1 then
  578. model:add(skip(backend, ch, m))
  579. else
  580. model:add(skip(backend, m, m))
  581. end
  582. end
  583. model:add(nn.FlattenTable())
  584. model:add(concat(backend, n, ch, m))
  585. model:add(SpatialFullConvolution(backend, m * (n - 1) + 3, ch, 4, 4, 2, 2, 3, 3):noBias())
  586. model:add(w2nn.InplaceClip01())
  587. model:add(nn.View(-1):setNumInputDims(3))
  588. model.w2nn_arch_name = "cupconv_14"
  589. model.w2nn_offset = 28
  590. model.w2nn_scale_factor = 2
  591. model.w2nn_channels = ch
  592. model.w2nn_resize = true
  593. return model
  594. end
  595. function srcnn.upconv_refine(backend, ch)
  596. local function block(backend, ch)
  597. local seq = nn.Sequential()
  598. local con = nn.ConcatTable()
  599. local res = nn.Sequential()
  600. local base = nn.Sequential()
  601. local refine = nn.Sequential()
  602. local aux_con = nn.ConcatTable()
  603. res:add(w2nn.GradWeight(0.1))
  604. res:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
  605. res:add(nn.LeakyReLU(0.1, true))
  606. res:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
  607. res:add(nn.LeakyReLU(0.1, true))
  608. res:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  609. res:add(nn.LeakyReLU(0.1, true))
  610. res:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0):noBias())
  611. res:add(w2nn.InplaceClip01())
  612. res:add(nn.MulConstant(0.5))
  613. con:add(res)
  614. con:add(nn.Sequential():add(nn.SpatialZeroPadding(-4, -4, -4, -4)):add(nn.MulConstant(0.5)))
  615. -- main output
  616. refine:add(nn.CAddTable()) -- averaging
  617. refine:add(nn.View(-1):setNumInputDims(3))
  618. -- aux output
  619. base:add(nn.SelectTable(2))
  620. base:add(nn.MulConstant(2)) -- revert mul 0.5
  621. base:add(nn.View(-1):setNumInputDims(3))
  622. aux_con:add(refine)
  623. aux_con:add(base)
  624. seq:add(con)
  625. seq:add(aux_con)
  626. seq:add(w2nn.AuxiliaryLossTable(1))
  627. return seq
  628. end
  629. local model = nn.Sequential()
  630. model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
  631. model:add(nn.LeakyReLU(0.1, true))
  632. model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
  633. model:add(nn.LeakyReLU(0.1, true))
  634. model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
  635. model:add(nn.LeakyReLU(0.1, true))
  636. model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  637. model:add(nn.LeakyReLU(0.1, true))
  638. model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
  639. model:add(nn.LeakyReLU(0.1, true))
  640. model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
  641. model:add(nn.LeakyReLU(0.1, true))
  642. model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
  643. model:add(w2nn.InplaceClip01())
  644. model:add(block(backend, ch))
  645. model.w2nn_arch_name = "upconv_refine"
  646. model.w2nn_offset = 18
  647. model.w2nn_scale_factor = 2
  648. model.w2nn_resize = true
  649. model.w2nn_channels = ch
  650. return model
  651. end
  652. -- cascade u-net
  653. function srcnn.cunet_v1(backend, ch)
  654. function unet_branch(insert, backend, n_input, n_output, depad)
  655. local block = nn.Sequential()
  656. local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
  657. --block:add(w2nn.Print())
  658. block:add(pooling)
  659. block:add(insert)
  660. block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
  661. local parallel = nn.ConcatTable(2)
  662. parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
  663. parallel:add(block)
  664. local model = nn.Sequential()
  665. model:add(parallel)
  666. model:add(nn.JoinTable(2))
  667. return model
  668. end
  669. function unet_conv(n_input, n_middle, n_output)
  670. local model = nn.Sequential()
  671. model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
  672. model:add(nn.LeakyReLU(0.1, true))
  673. model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
  674. return model
  675. end
  676. function unet(backend, ch, deconv)
  677. --
  678. local block1 = unet_conv(128, 256, 128)
  679. local block2 = nn.Sequential()
  680. block2:add(unet_conv(32, 64, 128))
  681. block2:add(unet_branch(block1, backend, 128, 128, 4))
  682. block2:add(unet_conv(128*2, 64, 32))
  683. local model = nn.Sequential()
  684. model:add(unet_conv(ch, 32, 32))
  685. model:add(unet_branch(block2, backend, 32, 32, 16))
  686. model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  687. model:add(nn.LeakyReLU(0.1))
  688. if deconv then
  689. model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 3, 3))
  690. else
  691. model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
  692. end
  693. return model
  694. end
  695. local model = nn.Sequential()
  696. local con = nn.ConcatTable()
  697. local aux_con = nn.ConcatTable()
  698. model:add(unet(backend, ch, true))
  699. con:add(unet(backend, ch, false))
  700. con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
  701. aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
  702. aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
  703. model:add(con)
  704. model:add(aux_con)
  705. model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
  706. model.w2nn_arch_name = "cunet_v1"
  707. model.w2nn_offset = 60
  708. model.w2nn_scale_factor = 2
  709. model.w2nn_channels = ch
  710. model.w2nn_resize = true
  711. -- 72, 128, 256 are valid
  712. --model.w2nn_input_size = 128
  713. return model
  714. end
  715. -- cascade u-net
  716. function srcnn.cunet_v2(backend, ch)
  717. function unet_branch(insert, backend, n_input, n_output, depad)
  718. local block = nn.Sequential()
  719. local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
  720. --block:add(w2nn.Print())
  721. block:add(pooling)
  722. block:add(insert)
  723. block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
  724. local parallel = nn.ConcatTable(2)
  725. parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
  726. parallel:add(block)
  727. local model = nn.Sequential()
  728. model:add(parallel)
  729. model:add(nn.CAddTable(2))
  730. return model
  731. end
  732. function unet_conv(n_input, n_middle, n_output)
  733. local model = nn.Sequential()
  734. model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
  735. model:add(nn.LeakyReLU(0.1, true))
  736. model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
  737. return model
  738. end
  739. -- res unet
  740. function unet(backend, ch, deconv)
  741. local block1 = unet_conv(128, 256, 128)
  742. local block2 = nn.Sequential()
  743. block2:add(unet_conv(64, 128, 128))
  744. block2:add(unet_branch(block1, backend, 128, 128, 4))
  745. block2:add(unet_conv(128, 128, 64))
  746. local model = nn.Sequential()
  747. model:add(nn.SpatialZeroPadding(-1, -1, -1, -1))
  748. model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
  749. model:add(unet_branch(block2, backend, 64, 64, 16))
  750. if deconv then
  751. model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  752. model:add(nn.LeakyReLU(0.1))
  753. model:add(SpatialFullConvolution(backend, 128, 64, 4, 4, 2, 2, 3, 3))
  754. else
  755. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  756. end
  757. return model
  758. end
  759. local model = nn.Sequential()
  760. local con = nn.ConcatTable()
  761. local aux_con = nn.ConcatTable()
  762. model:add(unet(backend, ch, true))
  763. con:add(unet(backend, 64, false))
  764. con:add(nn.SpatialZeroPadding(-19, -19, -19, -19))
  765. model:add(con)
  766. model:add(nn.CAddTable())
  767. model:add(nn.LeakyReLU(0.1, true))
  768. model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
  769. model.w2nn_arch_name = "cunet_v2"
  770. model.w2nn_offset = 60
  771. model.w2nn_scale_factor = 2
  772. model.w2nn_channels = ch
  773. model.w2nn_resize = true
  774. -- 72, 128, 256 are valid
  775. --model.w2nn_input_size = 128
  776. return model
  777. end
  778. -- cascade u-net
  779. function srcnn.cunet_v3(backend, ch)
  780. function unet_branch(insert, backend, n_input, n_output, depad)
  781. local block = nn.Sequential()
  782. local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
  783. --block:add(w2nn.Print())
  784. block:add(pooling)
  785. block:add(insert)
  786. block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
  787. local parallel = nn.ConcatTable(2)
  788. parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
  789. parallel:add(block)
  790. local model = nn.Sequential()
  791. model:add(parallel)
  792. model:add(nn.CAddTable())
  793. return model
  794. end
  795. function unet_conv(n_input, n_middle, n_output)
  796. local model = nn.Sequential()
  797. model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
  798. model:add(nn.LeakyReLU(0.1, true))
  799. model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
  800. model:add(nn.LeakyReLU(0.1, true))
  801. return model
  802. end
  803. function unet(backend, ch, deconv)
  804. local block1 = unet_conv(128, 256, 128)
  805. local block2 = nn.Sequential()
  806. block2:add(unet_conv(64, 64, 128))
  807. block2:add(unet_branch(block1, backend, 128, 128, 4))
  808. block2:add(unet_conv(128, 64, 64))
  809. local model = nn.Sequential()
  810. model:add(unet_conv(ch, 32, 64))
  811. model:add(unet_branch(block2, backend, 64, 64, 16))
  812. if deconv then
  813. model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
  814. model:add(nn.LeakyReLU(0.1))
  815. model:add(SpatialFullConvolution(backend, 128, 64, 4, 4, 2, 2, 3, 3))
  816. end
  817. return model
  818. end
  819. local model = nn.Sequential()
  820. local con = nn.ConcatTable()
  821. model:add(unet(backend, ch, true))
  822. model:add(nn.ConcatTable():add(unet(backend, 64, false)):add(nn.SpatialZeroPadding(-18, -18, -18, -18)))
  823. model:add(nn.CAddTable())
  824. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  825. model:add(nn.LeakyReLU())
  826. model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
  827. model:add(w2nn.InplaceClip01())
  828. model.w2nn_arch_name = "cunet_v3"
  829. model.w2nn_offset = 60
  830. model.w2nn_scale_factor = 2
  831. model.w2nn_channels = ch
  832. model.w2nn_resize = true
  833. -- 72, 128, 256 are valid
  834. --model.w2nn_input_size = 128
  835. return model
  836. end
  837. -- cascade u-net
  838. function srcnn.cunet_v4(backend, ch)
  839. function upconv_3(backend, n_input, n_output)
  840. local model = nn.Sequential()
  841. model:add(SpatialConvolution(backend, n_input, 32, 3, 3, 1, 1, 0, 0))
  842. model:add(nn.LeakyReLU(0.1, true))
  843. model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
  844. model:add(nn.LeakyReLU(0.1, true))
  845. model:add(SpatialFullConvolution(backend, 32, n_output, 4, 4, 2, 2, 3, 3):noBias())
  846. return model
  847. end
  848. function unet_branch(insert, backend, n_input, n_output, depad)
  849. local block = nn.Sequential()
  850. local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
  851. --block:add(w2nn.Print())
  852. block:add(pooling)
  853. block:add(insert)
  854. block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
  855. local parallel = nn.ConcatTable(2)
  856. parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
  857. parallel:add(block)
  858. local model = nn.Sequential()
  859. model:add(parallel)
  860. model:add(nn.CAddTable())
  861. return model
  862. end
  863. function unet_conv(n_input, n_middle, n_output)
  864. local model = nn.Sequential()
  865. model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
  866. model:add(nn.LeakyReLU(0.1, true))
  867. model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
  868. model:add(nn.LeakyReLU(0.1, true))
  869. return model
  870. end
  871. function unet(backend, ch)
  872. local block1 = unet_conv(128, 256, 128)
  873. local block2 = nn.Sequential()
  874. block2:add(unet_conv(64, 64, 128))
  875. block2:add(unet_branch(block1, backend, 128, 128, 4))
  876. block2:add(unet_conv(128, 64, 64))
  877. local model = nn.Sequential()
  878. model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
  879. model:add(nn.LeakyReLU(0.1, true))
  880. model:add(unet_branch(block2, backend, 64, 64, 16))
  881. return model
  882. end
  883. local model = nn.Sequential()
  884. local con = nn.ConcatTable()
  885. local aux_con = nn.ConcatTable()
  886. model:add(upconv_3(backend, ch, 64))
  887. con:add(unet(backend, 32))
  888. --con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
  889. aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
  890. aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single output
  891. model:add(conn)
  892. model:add(nn.CAddTable())
  893. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  894. model:add(nn.LeakyReLU())
  895. model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
  896. model:add(w2nn.InplaceClip01())
  897. model.w2nn_arch_name = "cunet_v3"
  898. model.w2nn_offset = 60
  899. model.w2nn_scale_factor = 2
  900. model.w2nn_channels = ch
  901. model.w2nn_resize = true
  902. -- 72, 128, 256 are valid
  903. --model.w2nn_input_size = 128
  904. return model
  905. end
  906. function srcnn.cunet_v6(backend, ch)
  907. function unet_branch(insert, backend, n_input, n_output, depad)
  908. local block = nn.Sequential()
  909. local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
  910. --block:add(w2nn.Print())
  911. block:add(pooling)
  912. block:add(insert)
  913. block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
  914. local parallel = nn.ConcatTable(2)
  915. parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
  916. parallel:add(block)
  917. local model = nn.Sequential()
  918. model:add(parallel)
  919. model:add(nn.CAddTable())
  920. return model
  921. end
  922. function unet_conv(n_input, n_middle, n_output, se)
  923. local model = nn.Sequential()
  924. model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
  925. model:add(nn.LeakyReLU(0.1, true))
  926. model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
  927. model:add(nn.LeakyReLU(0.1, true))
  928. if se then
  929. -- Squeeze and Excitation Networks
  930. local con = nn.ConcatTable(2)
  931. local attention = nn.Sequential()
  932. attention:add(nn.SpatialAdaptiveAveragePooling(1, 1)) -- global average pooling
  933. attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / 4), 1, 1, 1, 1, 0, 0))
  934. attention:add(nn.ReLU(true))
  935. attention:add(SpatialConvolution(backend, math.floor(n_output / 4), n_output, 1, 1, 1, 1, 0, 0))
  936. attention:add(nn.Sigmoid(true))
  937. con:add(nn.Identity())
  938. con:add(attention)
  939. model:add(con)
  940. model:add(w2nn.ScaleTable())
  941. end
  942. return model
  943. end
  944. -- Residual U-Net
  945. function unet(backend, ch, deconv)
  946. local block1 = unet_conv(128, 256, 128, true)
  947. local block2 = nn.Sequential()
  948. block2:add(unet_conv(64, 64, 128, true))
  949. block2:add(unet_branch(block1, backend, 128, 128, 4))
  950. block2:add(unet_conv(128, 64, 64, true))
  951. local model = nn.Sequential()
  952. model:add(unet_conv(ch, 32, 64, false))
  953. model:add(unet_branch(block2, backend, 64, 64, 16))
  954. model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
  955. model:add(nn.LeakyReLU(0.1))
  956. if deconv then
  957. model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
  958. else
  959. model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
  960. end
  961. return model
  962. end
  963. local model = nn.Sequential()
  964. local con = nn.ConcatTable()
  965. local aux_con = nn.ConcatTable()
  966. model:add(unet(backend, ch, true))
  967. con:add(unet(backend, ch, false))
  968. con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
  969. aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
  970. aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
  971. model:add(con)
  972. model:add(aux_con)
  973. model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
  974. model.w2nn_arch_name = "cunet_v6"
  975. model.w2nn_offset = 60
  976. model.w2nn_scale_factor = 2
  977. model.w2nn_channels = ch
  978. model.w2nn_resize = true
  979. -- 72, 128, 256 are valid
  980. --model.w2nn_input_size = 128
  981. return model
  982. end
  983. function srcnn.prog_net(backend, ch)
  984. function base_upscaler(backend, ch)
  985. local model = nn.Sequential()
  986. model:add(nn.SpatialZeroPadding(-11, -11, -11, -11))
  987. model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
  988. model:add(nn.LeakyReLU(0.1, true))
  989. model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
  990. model:add(nn.LeakyReLU(0.1, true))
  991. model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3):noBias())
  992. model:add(w2nn.InplaceClip01())
  993. return model
  994. end
  995. function block(backend, input, output)
  996. local con = nn.ConcatTable()
  997. local conv = nn.Sequential()
  998. local dil = nn.Sequential()
  999. local b = nn.Sequential()
  1000. conv:add(SpatialConvolution(backend, input, output, 3, 3, 1, 1, 0, 0))
  1001. conv:add(nn.SpatialZeroPadding(-5, -5, -5, -5))
  1002. dil:add(SpatialDilatedConvolution(backend, input, output, 3, 3, 1, 1, 0, 0, 2, 2))
  1003. dil:add(nn.LeakyReLU(0.1, true))
  1004. dil:add(SpatialDilatedConvolution(backend, output, output, 3, 3, 1, 1, 0, 0, 4, 4))
  1005. con:add(conv)
  1006. con:add(dil)
  1007. b:add(con)
  1008. b:add(nn.CAddTable())
  1009. b:add(nn.LeakyReLU(0.1, true))
  1010. return b
  1011. end
  1012. function texture_upscaler(backend, ch)
  1013. local model = nn.Sequential()
  1014. model:add(w2nn.EdgeFilter(ch))
  1015. model:add(SpatialConvolution(backend, ch * 8, 32, 1, 1, 1, 1, 0, 0))
  1016. model:add(nn.LeakyReLU(0.1, true))
  1017. model:add(block(backend, 32, 128))
  1018. model:add(block(backend, 128, 256))
  1019. model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
  1020. return model
  1021. end
  1022. local model = nn.Sequential()
  1023. local con = nn.ConcatTable()
  1024. local aux = nn.ConcatTable()
  1025. con:add(base_upscaler(backend, ch))
  1026. con:add(texture_upscaler(backend, ch))
  1027. aux:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01()):add(nn.View(-1):setNumInputDims(3)))
  1028. aux:add(nn.Sequential():add(nn.SelectTable(1)):add(nn.View(-1):setNumInputDims(3)))
  1029. model:add(con)
  1030. model:add(aux)
  1031. model:add(w2nn.AuxiliaryLossTable(1))
  1032. model.w2nn_arch_name = "prog_net"
  1033. model.w2nn_offset = 28
  1034. model.w2nn_scale_factor = 2
  1035. model.w2nn_channels = ch
  1036. model.w2nn_resize = true
  1037. return model
  1038. end
  1039. function srcnn.create(model_name, backend, color)
  1040. model_name = model_name or "vgg_7"
  1041. backend = backend or "cunn"
  1042. color = color or "rgb"
  1043. local ch = 3
  1044. if color == "rgb" then
  1045. ch = 3
  1046. elseif color == "y" then
  1047. ch = 1
  1048. else
  1049. error("unsupported color: " .. color)
  1050. end
  1051. if srcnn[model_name] then
  1052. local model = srcnn[model_name](backend, ch)
  1053. assert(model.w2nn_offset % model.w2nn_scale_factor == 0)
  1054. return model
  1055. else
  1056. error("unsupported model_name: " .. model_name)
  1057. end
  1058. end
  1059. --[[
  1060. local model = srcnn.fcn_v1("cunn", 3):cuda()
  1061. print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
  1062. print(model)
  1063. local model = srcnn.unet_refine("cunn", 3):cuda()
  1064. print(model)
  1065. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
  1066. local model = srcnn.cupconv_14("cunn", 3):cuda()
  1067. print(model)
  1068. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
  1069. os.exit()
  1070. local model = srcnn.cupconv_14("cunn", 3):cuda()
  1071. print(model)
  1072. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
  1073. os.exit()
  1074. local model = srcnn.upconv_refine("cunn", 3):cuda()
  1075. print(model)
  1076. model:training()
  1077. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()))
  1078. os.exit()
  1079. local model = srcnn.nw2("cunn", 3):cuda()
  1080. print(model)
  1081. model:training()
  1082. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()))
  1083. os.exit()
  1084. local model = srcnn.prog_net("cunn", 3):cuda()
  1085. print(model)
  1086. model:training()
  1087. print(model:forward(torch.Tensor(1, 3, 128, 128):zero():cuda()))
  1088. os.exit()
  1089. local model = srcnn.double_unet("cunn", 3):cuda()
  1090. print(model)
  1091. model:training()
  1092. print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()))
  1093. os.exit()
  1094. local model = srcnn.cunet_v3("cunn", 3):cuda()
  1095. print(model)
  1096. model:training()
  1097. print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()):size())
  1098. os.exit()
  1099. local model = srcnn.cunet_v6("cunn", 3):cuda()
  1100. print(model)
  1101. model:training()
  1102. print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()))
  1103. os.exit()
  1104. --]]
  1105. return srcnn