srcnn.lua

require 'w2nn'
-- ref: http://arxiv.org/abs/1502.01852
-- ref: http://arxiv.org/abs/1501.00092
local srcnn = {}

local function msra_filler(mod)
   local fin = mod.kW * mod.kH * mod.nInputPlane
   local fout = mod.kW * mod.kH * mod.nOutputPlane
   local stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
   mod.weight:normal(0, stdv)
   mod.bias:zero()
end
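-- Added note (my reading of the formula above, not from the reference papers verbatim):
-- this is an MSRA/He-style initializer; the variance 4 / ((1 + a^2) * (fin + fout))
-- averages fan-in and fan-out and assumes a LeakyReLU negative slope of a = 0.1,
-- matching the nn.LeakyReLU(0.1) activations used by the models below.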
local function identity_filler(mod)
   assert(mod.nInputPlane <= mod.nOutputPlane)
   mod.weight:normal(0, 0.01)
   mod.bias:zero()
   local num_groups = mod.nInputPlane -- fixed
   local filler_value = num_groups / mod.nOutputPlane
   local in_group_size = math.floor(mod.nInputPlane / num_groups)
   local out_group_size = math.floor(mod.nOutputPlane / num_groups)
   local x = math.floor(mod.kW / 2)
   local y = math.floor(mod.kH / 2)
   for i = 0, num_groups - 1 do
      for j = i * out_group_size, (i + 1) * out_group_size - 1 do
         for k = i * in_group_size, (i + 1) * in_group_size - 1 do
            mod.weight[j+1][k+1][y+1][x+1] = filler_value
         end
      end
   end
end
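-- Added note: identity_filler sets the center tap of each kernel so that every layer
-- starts out close to a pass-through mapping of its input channels; it is used below
-- as the reset() for SpatialDilatedConvolution layers.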
function nn.SpatialConvolutionMM:reset(stdv)
   msra_filler(self)
end
function nn.SpatialFullConvolution:reset(stdv)
   msra_filler(self)
end
function nn.SpatialDilatedConvolution:reset(stdv)
   identity_filler(self)
end
if cudnn and cudnn.SpatialConvolution then
   function cudnn.SpatialConvolution:reset(stdv)
      msra_filler(self)
   end
   function cudnn.SpatialFullConvolution:reset(stdv)
      msra_filler(self)
   end
   if cudnn.SpatialDilatedConvolution then
      function cudnn.SpatialDilatedConvolution:reset(stdv)
         identity_filler(self)
      end
   end
end

function nn.SpatialConvolutionMM:clearState()
   if self.gradWeight then
      self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
   end
   if self.gradBias then
      self.gradBias:resize(self.nOutputPlane):zero()
   end
   return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
function srcnn.channels(model)
   if model.w2nn_channels ~= nil then
      return model.w2nn_channels
   else
      return model:get(model:size() - 1).weight:size(1)
   end
end
function srcnn.backend(model)
   local conv = model:findModules("cudnn.SpatialConvolution")
   local fullconv = model:findModules("cudnn.SpatialFullConvolution")
   if #conv > 0 or #fullconv > 0 then
      return "cudnn"
   else
      return "cunn"
   end
end
function srcnn.color(model)
   local ch = srcnn.channels(model)
   if ch == 3 then
      return "rgb"
   else
      return "y"
   end
end
function srcnn.name(model)
   if model.w2nn_arch_name ~= nil then
      return model.w2nn_arch_name
   else
      local conv = model:findModules("nn.SpatialConvolutionMM")
      if #conv == 0 then
         conv = model:findModules("cudnn.SpatialConvolution")
      end
      if #conv == 7 then
         return "vgg_7"
      elseif #conv == 12 then
         return "vgg_12"
      else
         error("unsupported model")
      end
   end
end
function srcnn.offset_size(model)
   if model.w2nn_offset ~= nil then
      return model.w2nn_offset
   else
      local name = srcnn.name(model)
      if name:match("vgg_") then
         local conv = model:findModules("nn.SpatialConvolutionMM")
         if #conv == 0 then
            conv = model:findModules("cudnn.SpatialConvolution")
         end
         local offset = 0
         for i = 1, #conv do
            offset = offset + (conv[i].kW - 1) / 2
         end
         return math.floor(offset)
      else
         error("unsupported model")
      end
   end
end
function srcnn.scale_factor(model)
   if model.w2nn_scale_factor ~= nil then
      return model.w2nn_scale_factor
   else
      local name = srcnn.name(model)
      if name == "upconv_7" then
         return 2
      elseif name == "upconv_8_4x" then
         return 4
      else
         return 1
      end
   end
end
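-- Illustrative use of the introspection helpers above (a sketch; "model.t7" is a
-- hypothetical serialized model path):
-- local model = torch.load("model.t7", "ascii")
-- print(srcnn.name(model), srcnn.backend(model), srcnn.color(model))
-- print(srcnn.offset_size(model), srcnn.scale_factor(model))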
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialConvolution = SpatialConvolution

local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
   if backend == "cunn" then
      return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
   elseif backend == "cudnn" then
      return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialFullConvolution = SpatialFullConvolution

local function ReLU(backend)
   if backend == "cunn" then
      return nn.ReLU(true)
   elseif backend == "cudnn" then
      return cudnn.ReLU(true)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.ReLU = ReLU
local function Sigmoid(backend)
   if backend == "cunn" then
      return nn.Sigmoid(true)
   elseif backend == "cudnn" then
      return cudnn.Sigmoid(true)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.Sigmoid = Sigmoid
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialMaxPooling = SpatialMaxPooling

local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialAveragePooling = SpatialAveragePooling

local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   if backend == "cunn" then
      return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   elseif backend == "cudnn" then
      if cudnn.SpatialDilatedConvolution then
         -- cudnn v6
         return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
      else
         return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
      end
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
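-- Illustrative use of the backend dispatchers above (a sketch): the same call site
-- builds either an nn.* or a cudnn.* module depending on the backend string.
-- local conv_cpu = SpatialConvolution("cunn", 3, 32, 3, 3, 1, 1, 0, 0)  -- nn.SpatialConvolutionMM
-- local conv_gpu = SpatialConvolution("cudnn", 3, 32, 3, 3, 1, 1, 0, 0) -- cudnn.SpatialConvolution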
-- VGG style net (7 layers)
function srcnn.vgg_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "vgg_7"
   model.w2nn_offset = 7
   model.w2nn_scale_factor = 1
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
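-- Sanity check for vgg_7 (a sketch, assuming cutorch/cunn are installed): seven unpadded
-- 3x3 convolutions each crop 1 pixel per side, hence w2nn_offset = 7, so a (ch, 92, 92)
-- input becomes a (ch, 78, 78) map before nn.View flattens it.
-- local model = srcnn.vgg_7("cunn", 3):cuda()
-- print(model:forward(torch.Tensor(1, 3, 92, 92):uniform():cuda()):size()) -- 1 x 18252 (= 3*78*78)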
-- VGG style net (12 layers)
function srcnn.vgg_12(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "vgg_12"
   model.w2nn_offset = 12
   model.w2nn_scale_factor = 1
   model.w2nn_resize = false
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- Dilated Convolution (7 layers)
function srcnn.dilated_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "dilated_7"
   model.w2nn_offset = 12
   model.w2nn_scale_factor = 1
   model.w2nn_resize = false
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- Upconvolution
function srcnn.upconv_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "upconv_7"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   return model
end
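-- Added note on the offset arithmetic (my reading of the layers above): the six unpadded
-- 3x3 convolutions crop 6 pixels per side at input scale (12 at the 2x output scale), and
-- the 4x4/stride-2/pad-3 full convolution crops another 2 per side at output scale,
-- which is where w2nn_offset = 14 comes from.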
-- large version of upconv_7
-- This model can beat upconv_7 (PSNR: +0.3 ~ +0.8), but it is 2x slower than upconv_7.
function srcnn.upconv_7l(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "upconv_7l"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- layerwise linear blending with skip connections
-- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l
function srcnn.skiplb_7(backend, ch)
   local function skip(backend, i, o)
      local con = nn.Concat(2)
      local conv = nn.Sequential()
      conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1))
      conv:add(nn.LeakyReLU(0.1, true))
      -- depth concat
      con:add(conv)
      con:add(nn.Identity()) -- skip
      return con
   end
   local model = nn.Sequential()
   model:add(skip(backend, ch, 16))
   model:add(skip(backend, 16+ch, 32))
   model:add(skip(backend, 32+16+ch, 64))
   model:add(skip(backend, 64+32+16+ch, 128))
   model:add(skip(backend, 128+64+32+16+ch, 128))
   model:add(skip(backend, 128+128+64+32+16+ch, 256))
   -- input of the last layer = concatenation of all layerwise outputs (including the input layer)
   model:add(SpatialFullConvolution(backend, 256+128+128+64+32+16+ch, ch, 4, 4, 2, 2, 3, 3):noBias()) -- linear blend
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "skiplb_7"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- dilated convolution + deconvolution
-- Note: This model is not better than upconv_7. Maybe because of under-fitting.
function srcnn.dilated_upconv_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(128, 128, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "dilated_upconv_7"
   model.w2nn_offset = 20
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- ref: https://arxiv.org/abs/1609.04802
-- note: no batch-norm, no zero-padding
function srcnn.srresnet_2x(backend, ch)
   local function resblock(backend)
      local seq = nn.Sequential()
      local con = nn.ConcatTable()
      local conv = nn.Sequential()
      conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      conv:add(ReLU(backend))
      conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      conv:add(ReLU(backend))
      con:add(conv)
      con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
      seq:add(con)
      seq:add(nn.CAddTable())
      return seq
   end
   local model = nn.Sequential()
   --model:add(skip(backend, ch, 64 - ch))
   model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2))
   model:add(ReLU(backend))
   model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.InplaceClip01())
   --model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "srresnet_2x"
   model.w2nn_offset = 28
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- large version of srresnet_2x. It's the current best model, but slow.
function srcnn.resnet_14l(backend, ch)
   local function resblock(backend, i, o)
      local seq = nn.Sequential()
      local con = nn.ConcatTable()
      local conv = nn.Sequential()
      conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
      conv:add(nn.LeakyReLU(0.1, true))
      conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
      conv:add(nn.LeakyReLU(0.1, true))
      con:add(conv)
      if i == o then
         con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
      else
         local shortcut = nn.Sequential()
         shortcut:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
         shortcut:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
         con:add(shortcut)
      end
      seq:add(con)
      seq:add(nn.CAddTable())
      return seq
   end
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(resblock(backend, 32, 64))
   model:add(resblock(backend, 64, 64))
   model:add(resblock(backend, 64, 128))
   model:add(resblock(backend, 128, 128))
   model:add(resblock(backend, 128, 256))
   model:add(resblock(backend, 256, 256))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "resnet_14l"
   model.w2nn_offset = 28
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- for segmentation
function srcnn.fcn_v1(backend, ch)
   -- input_size = 120
   local model = nn.Sequential()
   --i = 120
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
   model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.Dropout(0.5, false, true))
   model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "fcn_v1"
   model.w2nn_offset = 36
   model.w2nn_scale_factor = 1
   model.w2nn_channels = ch
   model.w2nn_input_size = 120
   --model.w2nn_gcn = true
   return model
end
function srcnn.cupconv_14(backend, ch)
   local function skip(backend, n_input, n_output, pad)
      local con = nn.ConcatTable()
      local conv = nn.Sequential()
      local depad = nn.Sequential()
      conv:add(nn.SelectTable(1))
      conv:add(SpatialConvolution(backend, n_input, n_output, 3, 3, 1, 1, 0, 0))
      conv:add(nn.LeakyReLU(0.1, true))
      con:add(conv)
      con:add(nn.Identity())
      return con
   end
   local function concat(backend, n, ch, n_middle)
      local con = nn.ConcatTable()
      for i = 1, n do
         local pad = i - 1
         if i == 1 then
            con:add(nn.Sequential():add(nn.SelectTable(i)))
         else
            local seq = nn.Sequential()
            seq:add(nn.SelectTable(i))
            if pad > 0 then
               seq:add(nn.SpatialZeroPadding(-pad, -pad, -pad, -pad))
            end
            if i == n then
               --seq:add(SpatialConvolution(backend, ch, n_middle, 1, 1, 1, 1, 0, 0))
            else
               seq:add(w2nn.GradWeight(0.025))
               seq:add(SpatialConvolution(backend, n_middle, n_middle, 1, 1, 1, 1, 0, 0))
            end
            seq:add(nn.LeakyReLU(0.1, true))
            con:add(seq)
         end
      end
      return nn.Sequential():add(con):add(nn.JoinTable(2))
   end
   local model = nn.Sequential()
   local m = 64
   local n = 14
   model:add(nn.ConcatTable():add(nn.Identity()))
   for i = 1, n - 1 do
      if i == 1 then
         model:add(skip(backend, ch, m))
      else
         model:add(skip(backend, m, m))
      end
   end
   model:add(nn.FlattenTable())
   model:add(concat(backend, n, ch, m))
   model:add(SpatialFullConvolution(backend, m * (n - 1) + 3, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "cupconv_14"
   model.w2nn_offset = 28
   model.w2nn_scale_factor = 2
   model.w2nn_channels = ch
   model.w2nn_resize = true
   return model
end
function srcnn.upconv_refine(backend, ch)
   local function block(backend, ch)
      local seq = nn.Sequential()
      local con = nn.ConcatTable()
      local res = nn.Sequential()
      local base = nn.Sequential()
      local refine = nn.Sequential()
      local aux_con = nn.ConcatTable()
      res:add(w2nn.GradWeight(0.1))
      res:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
      res:add(nn.LeakyReLU(0.1, true))
      res:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
      res:add(nn.LeakyReLU(0.1, true))
      res:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
      res:add(nn.LeakyReLU(0.1, true))
      res:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0):noBias())
      res:add(w2nn.InplaceClip01())
      res:add(nn.MulConstant(0.5))
      con:add(res)
      con:add(nn.Sequential():add(nn.SpatialZeroPadding(-4, -4, -4, -4)):add(nn.MulConstant(0.5)))
      -- main output
      refine:add(nn.CAddTable()) -- averaging
      refine:add(nn.View(-1):setNumInputDims(3))
      -- aux output
      base:add(nn.SelectTable(2))
      base:add(nn.MulConstant(2)) -- revert mul 0.5
      base:add(nn.View(-1):setNumInputDims(3))
      aux_con:add(refine)
      aux_con:add(base)
      seq:add(con)
      seq:add(aux_con)
      seq:add(w2nn.AuxiliaryLossTable(1))
      return seq
   end
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(block(backend, ch))
   model.w2nn_arch_name = "upconv_refine"
   model.w2nn_offset = 18
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   return model
end
-- cascaded residual channel attention U-Net
function srcnn.upcunet(backend, ch)
   local function unet_branch(insert, backend, n_input, n_output, depad)
      local block = nn.Sequential()
      local con = nn.ConcatTable(2)
      local model = nn.Sequential()
      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
      block:add(insert)
      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
      con:add(block)
      model:add(con)
      model:add(nn.CAddTable())
      return model
   end
   local function unet_conv(n_input, n_middle, n_output, se)
      local model = nn.Sequential()
      model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      if se then
         -- Squeeze and Excitation Networks
         local con = nn.ConcatTable(2)
         local attention = nn.Sequential()
         attention:add(nn.SpatialAdaptiveAveragePooling(1, 1)) -- global average pooling
         attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / 4), 1, 1, 1, 1, 0, 0))
         attention:add(nn.ReLU(true))
         attention:add(SpatialConvolution(backend, math.floor(n_output / 4), n_output, 1, 1, 1, 1, 0, 0))
         attention:add(nn.Sigmoid(true))
         con:add(nn.Identity())
         con:add(attention)
         model:add(con)
         model:add(w2nn.ScaleTable())
      end
      return model
   end
   -- Residual U-Net
   local function unet(backend, ch, deconv)
      local block1 = unet_conv(128, 256, 128, true)
      local block2 = nn.Sequential()
      block2:add(unet_conv(64, 64, 128, true))
      block2:add(unet_branch(block1, backend, 128, 128, 4))
      block2:add(unet_conv(128, 64, 64, true))
      local model = nn.Sequential()
      model:add(unet_conv(ch, 32, 64, false))
      model:add(unet_branch(block2, backend, 64, 64, 16))
      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1))
      if deconv then
         model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
      else
         model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
      end
      return model
   end
   local model = nn.Sequential()
   local con = nn.ConcatTable()
   local aux_con = nn.ConcatTable()
   -- 2 cascade
   model:add(unet(backend, ch, true))
   con:add(unet(backend, ch, false))
   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
   model:add(con)
   model:add(aux_con)
   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
   model.w2nn_arch_name = "upcunet"
   model.w2nn_offset = 60
   model.w2nn_scale_factor = 2
   model.w2nn_channels = ch
   model.w2nn_resize = true
   -- 72, 128, 256 are valid input sizes
   --model.w2nn_input_size = 128
   return model
end
-- cascaded residual spatial channel attention U-Net
function srcnn.upcunet_v2(backend, ch)
   local function unet_branch(insert, backend, n_input, n_output, depad)
      local block = nn.Sequential()
      local con = nn.ConcatTable(2)
      local model = nn.Sequential()
      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
      block:add(insert)
      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
      con:add(block)
      model:add(con)
      model:add(nn.CAddTable())
      return model
   end
   local function unet_conv(n_input, n_middle, n_output, se)
      local model = nn.Sequential()
      model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      if se then
         -- Spatial Squeeze and Excitation Networks
         local se_fac = 4
         local con = nn.ConcatTable(2)
         local attention = nn.Sequential()
         attention:add(SpatialAveragePooling(backend, 4, 4, 4, 4))
         attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / se_fac), 1, 1, 1, 1, 0, 0))
         attention:add(nn.ReLU(true))
         attention:add(SpatialConvolution(backend, math.floor(n_output / se_fac), n_output, 1, 1, 1, 1, 0, 0))
         attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
         attention:add(nn.SpatialUpSamplingNearest(4, 4))
         con:add(nn.Identity())
         con:add(attention)
         model:add(con)
         model:add(nn.CMulTable())
      end
      return model
   end
   -- Residual U-Net
   local function unet(backend, in_ch, out_ch, deconv)
      local block1 = unet_conv(128, 256, 128, true)
      local block2 = nn.Sequential()
      block2:add(unet_conv(64, 64, 128, true))
      block2:add(unet_branch(block1, backend, 128, 128, 4))
      block2:add(unet_conv(128, 64, 64, true))
      local model = nn.Sequential()
      model:add(unet_conv(in_ch, 32, 64, false))
      model:add(unet_branch(block2, backend, 64, 64, 16))
      if deconv then
         model:add(SpatialFullConvolution(backend, 64, out_ch, 4, 4, 2, 2, 3, 3):noBias())
      else
         model:add(SpatialConvolution(backend, 64, out_ch, 3, 3, 1, 1, 0, 0):noBias())
      end
      return model
   end
   local model = nn.Sequential()
   local con = nn.ConcatTable()
   local aux_con = nn.ConcatTable()
   -- 2 cascade
   model:add(unet(backend, ch, ch, true))
   con:add(nn.Sequential():add(unet(backend, ch, ch, false)):add(nn.SpatialZeroPadding(-1, -1, -1, -1))) -- -1 for odd output size
   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
   model:add(con)
   model:add(aux_con)
   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
   model.w2nn_arch_name = "upcunet_v2"
   model.w2nn_offset = 58
   model.w2nn_scale_factor = 2
   model.w2nn_channels = ch
   model.w2nn_resize = true
   -- {76, 92, 108, 140} are also valid input sizes, but they are too small
   model.w2nn_valid_input_size = {156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
   return model
end
local function bench()
   local sys = require 'sys'
   cudnn.benchmark = false
   local model = nil
   local arch = {"upconv_7", "upcunet", "upcunet_v2"}
   local backend = "cunn"
   for k = 1, #arch do
      model = srcnn[arch[k]](backend, 3):cuda()
      model:training()
      local t = sys.clock()
      for i = 1, 10 do
         model:forward(torch.Tensor(1, 3, 172, 172):zero():cuda())
      end
      print(arch[k], sys.clock() - t)
   end
end
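-- bench() is not invoked anywhere in this file; to get a rough forward-pass timing,
-- uncomment the call below (requires cutorch/cunn; cudnn only for the "cudnn" backend):
-- bench()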
function srcnn.create(model_name, backend, color)
   model_name = model_name or "vgg_7"
   backend = backend or "cunn"
   color = color or "rgb"
   local ch = 3
   if color == "rgb" then
      ch = 3
   elseif color == "y" then
      ch = 1
   else
      error("unsupported color: " .. color)
   end
   if srcnn[model_name] then
      local model = srcnn[model_name](backend, ch)
      assert(model.w2nn_offset % model.w2nn_scale_factor == 0)
      return model
   else
      error("unsupported model_name: " .. model_name)
   end
end
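-- Typical usage (a sketch; any architecture name defined above works the same way):
-- local model = srcnn.create("upconv_7", "cunn", "rgb")
-- print(srcnn.offset_size(model), srcnn.scale_factor(model)) -- 14  2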
--[[
local model = srcnn.cunet_v3("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()):size())
local model = srcnn.upcunet_v2("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
os.exit()
--]]
return srcnn