srcnn.lua 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939
  1. require 'w2nn'
  -- ref: http://arxiv.org/abs/1502.01852  (He et al., "Delving Deep into Rectifiers")
  -- ref: http://arxiv.org/abs/1501.00092  (Dong et al., SRCNN)
  -- Module table: the initialisers, backend helpers and model constructors
  -- in this file are attached to it.
  local srcnn = {}
  -- MSRA / He weight initialisation (arXiv:1502.01852) adapted to a LeakyReLU
  -- with negative slope `a` (default 0.1, the slope used by the models here):
  --   std = sqrt(2 / ((1 + a^2) * n)),  with n = (fan_in + fan_out) / 2
  -- The weight is drawn from N(0, std). The bias is zeroed only when present:
  -- it is presumably nil after the module's :noBias().
  -- @param mod    convolution module exposing kW, kH, nInputPlane, nOutputPlane, weight
  -- @param slope  optional LeakyReLU negative slope; defaults to 0.1
  local function msra_filler(mod, slope)
    local a = slope or 0.1
    local fan_in = mod.kW * mod.kH * mod.nInputPlane
    local fan_out = mod.kW * mod.kH * mod.nOutputPlane
    -- `local` so the value does not leak into a global named `stdv`
    local stdv = math.sqrt(4 / ((1.0 + a * a) * (fan_in + fan_out)))
    mod.weight:normal(0, stdv)
    if mod.bias then
      mod.bias:zero()
    end
  end
  -- Identity-like initialisation (used for dilated convolutions).
  -- 1. The weight gets N(0, 0.01) noise.
  -- 2. At the kernel centre, input plane k is connected to output planes
  --    [k*g, (k+1)*g) with value nInputPlane / nOutputPlane, where
  --    g = floor(nOutputPlane / nInputPlane).
  -- 3. Leftover output planes (when nOutputPlane is not a multiple of
  --    nInputPlane) keep only the noise.
  -- Requires nInputPlane <= nOutputPlane. The bias is zeroed only when present.
  -- @param mod  convolution module exposing kW, kH, nInputPlane, nOutputPlane, weight
  local function identity_filler(mod)
    assert(mod.nInputPlane <= mod.nOutputPlane)
    mod.weight:normal(0, 0.01)
    if mod.bias then
      mod.bias:zero()
    end
    local num_groups = mod.nInputPlane -- fixed: one group per input plane
    local filler_value = num_groups / mod.nOutputPlane
    local in_group_size = math.floor(mod.nInputPlane / num_groups)
    local out_group_size = math.floor(mod.nOutputPlane / num_groups)
    -- kernel centre, 0-based
    local x = math.floor(mod.kW / 2)
    local y = math.floor(mod.kH / 2)
    for i = 0, num_groups - 1 do
      for j = i * out_group_size, (i + 1) * out_group_size - 1 do
        for k = i * in_group_size, (i + 1) * in_group_size - 1 do
          -- assumes weight layout nOutputPlane x nInputPlane x kH x kW
          mod.weight[j + 1][k + 1][y + 1][x + 1] = filler_value
        end
      end
    end
  end
  -- Replace the stock weight initialisation of the nn convolution modules.
  -- The `stdv` argument of reset() is accepted but ignored.
  -- Plain and transposed convolutions use msra_filler; dilated convolutions
  -- use identity_filler.
  nn.SpatialConvolutionMM.reset = function(self, stdv)
    msra_filler(self)
  end
  nn.SpatialFullConvolution.reset = function(self, stdv)
    msra_filler(self)
  end
  nn.SpatialDilatedConvolution.reset = function(self, stdv)
    identity_filler(self)
  end
  -- Same initialisation overrides for the cudnn modules, installed only when
  -- cudnn is loaded. cudnn.SpatialDilatedConvolution exists only in newer
  -- bindings (cudnn v6), so it is checked on its own.
  if cudnn and cudnn.SpatialConvolution then
    for _, class_name in ipairs({"SpatialConvolution", "SpatialFullConvolution"}) do
      cudnn[class_name].reset = function(self, stdv)
        msra_filler(self)
      end
    end
    if cudnn.SpatialDilatedConvolution then
      cudnn.SpatialDilatedConvolution.reset = function(self, stdv)
        identity_filler(self)
      end
    end
  end
  -- Slim the module down. Existing gradient buffers are kept, but:
  -- - gradWeight is reshaped to nOutputPlane x (nInputPlane*kH*kW) and zeroed;
  -- - gradBias is reshaped to nOutputPlane and zeroed;
  -- - the listed buffer fields are handed to nn.utils.clear.
  -- Presumably called before the model is serialised (confirm at call sites).
  function nn.SpatialConvolutionMM:clearState()
    local grad_w, grad_b = self.gradWeight, self.gradBias
    if grad_w then
      local columns = self.nInputPlane * self.kH * self.kW
      grad_w:resize(self.nOutputPlane, columns):zero()
    end
    if grad_b then
      grad_b:resize(self.nOutputPlane):zero()
    end
    return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
  end
  -- Number of image channels a model works on.
  -- Models built here record it in w2nn_channels. For older models it is read
  -- from the first dimension of the weight of the second-to-last module.
  function srcnn.channels(model)
    local recorded = model.w2nn_channels
    if recorded ~= nil then
      return recorded
    end
    local penultimate = model:get(model:size() - 1)
    return penultimate.weight:size(1)
  end
  -- Guess the backend a model was built with. Returns "cudnn" when it contains
  -- any cudnn.SpatialConvolution or cudnn.SpatialFullConvolution module,
  -- otherwise "cunn".
  function srcnn.backend(model)
    for _, type_name in ipairs({"cudnn.SpatialConvolution", "cudnn.SpatialFullConvolution"}) do
      if #model:findModules(type_name) > 0 then
        return "cudnn"
      end
    end
    return "cunn"
  end
  -- Colour mode of a model: "rgb" for 3-channel models, "y" for anything else.
  function srcnn.color(model)
    return (srcnn.channels(model) == 3) and "rgb" or "y"
  end
  -- Architecture name of a model.
  -- Uses w2nn_arch_name when recorded. Legacy models are identified by how many
  -- convolution layers they have (nn.SpatialConvolutionMM first, cudnn as
  -- fallback): 7 -> "vgg_7", 12 -> "vgg_12".
  -- Any other count raises "unsupported model".
  function srcnn.name(model)
    local arch = model.w2nn_arch_name
    if arch ~= nil then
      return arch
    end
    local convs = model:findModules("nn.SpatialConvolutionMM")
    if #convs == 0 then
      convs = model:findModules("cudnn.SpatialConvolution")
    end
    local by_depth = {[7] = "vgg_7", [12] = "vgg_12"}
    local found = by_depth[#convs]
    if found == nil then
      error("unsupported model")
    end
    return found
  end
  -- Border (in pixels, apparently per side) that the model's unpadded
  -- convolutions trim from the image.
  -- Uses w2nn_offset when recorded. For legacy vgg_* models it is
  -- floor(sum of (kW - 1) / 2) over all convolution layers.
  -- Other legacy models raise "unsupported model".
  function srcnn.offset_size(model)
    if model.w2nn_offset ~= nil then
      return model.w2nn_offset
    end
    if not srcnn.name(model):match("vgg_") then
      error("unsupported model")
    end
    local convs = model:findModules("nn.SpatialConvolutionMM")
    if #convs == 0 then
      convs = model:findModules("cudnn.SpatialConvolution")
    end
    local total = 0
    for _, conv in ipairs(convs) do
      total = total + (conv.kW - 1) / 2
    end
    return math.floor(total)
  end
  -- Upscaling factor of a model.
  -- Uses w2nn_scale_factor when recorded; otherwise looks up the architecture
  -- name (upconv_7 -> 2, upconv_8_4x -> 4) and defaults to 1.
  function srcnn.scale_factor(model)
    if model.w2nn_scale_factor ~= nil then
      return model.w2nn_scale_factor
    end
    local known_factors = {upconv_7 = 2, upconv_8_4x = 4}
    return known_factors[srcnn.name(model)] or 1
  end
  -- Backend-dispatching 2-D convolution constructor:
  --   "cunn"  -> nn.SpatialConvolutionMM
  --   "cudnn" -> cudnn.SpatialConvolution
  -- Any other backend raises "unsupported backend:<name>".
  local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
    local ctor
    if backend == "cudnn" then
      ctor = cudnn.SpatialConvolution
    elseif backend == "cunn" then
      ctor = nn.SpatialConvolutionMM
    else
      error("unsupported backend:" .. backend)
    end
    return ctor(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
  end
  srcnn.SpatialConvolution = SpatialConvolution
  -- Backend-dispatching transposed ("full") convolution constructor:
  --   "cunn"  -> nn.SpatialFullConvolution
  --   "cudnn" -> cudnn.SpatialFullConvolution
  -- adjW/adjH (optional extra output width/height) are forwarded to BOTH
  -- backends, so cunn and cudnn builds of the same layer have the same output
  -- size. Omitting them (nil) keeps the modules' default.
  -- Any other backend raises "unsupported backend:<name>".
  local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    if backend == "cunn" then
      return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    elseif backend == "cudnn" then
      return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
    else
      error("unsupported backend:" .. backend)
    end
  end
  srcnn.SpatialFullConvolution = SpatialFullConvolution
  -- In-place ReLU from the requested backend namespace ("cunn" -> nn,
  -- "cudnn" -> cudnn). Any other backend raises "unsupported backend:<name>".
  local function ReLU(backend)
    local impl
    if backend == "cunn" then
      impl = nn
    elseif backend == "cudnn" then
      impl = cudnn
    else
      error("unsupported backend:" .. backend)
    end
    return impl.ReLU(true)
  end
  srcnn.ReLU = ReLU
  -- Max pooling from the requested backend namespace ("cunn" -> nn,
  -- "cudnn" -> cudnn). Any other backend raises "unsupported backend:<name>".
  local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    local impl
    if backend == "cunn" then
      impl = nn
    elseif backend == "cudnn" then
      impl = cudnn
    else
      error("unsupported backend:" .. backend)
    end
    return impl.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
  end
  srcnn.SpatialMaxPooling = SpatialMaxPooling
  -- Average pooling from the requested backend namespace ("cunn" -> nn,
  -- "cudnn" -> cudnn). Any other backend raises "unsupported backend:<name>".
  local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
    local impl
    if backend == "cunn" then
      impl = nn
    elseif backend == "cudnn" then
      impl = cudnn
    else
      error("unsupported backend:" .. backend)
    end
    return impl.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
  end
  srcnn.SpatialAveragePooling = SpatialAveragePooling
  -- Backend-dispatching dilated convolution constructor.
  -- cudnn only provides this module in newer bindings (cudnn v6). When it is
  -- missing, the nn implementation is used even for the "cudnn" backend.
  -- Any other backend raises "unsupported backend:<name>".
  local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
    local ctor
    if backend == "cudnn" then
      ctor = cudnn.SpatialDilatedConvolution or nn.SpatialDilatedConvolution
    elseif backend == "cunn" then
      ctor = nn.SpatialDilatedConvolution
    else
      error("unsupported backend:" .. backend)
    end
    return ctor(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
  end
  srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
  -- VGG style net (7 layers): seven unpadded 3x3 convolutions with
  -- LeakyReLU(0.1) between them. The output goes through w2nn.InplaceClip01 and
  -- is then flattened per sample (View over the last 3 dims).
  function srcnn.vgg_7(backend, ch)
    local layers = {
      {ch, 32}, {32, 32}, {32, 64}, {64, 64}, {64, 128}, {128, 128}, {128, ch},
    }
    local model = nn.Sequential()
    for idx, planes in ipairs(layers) do
      model:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
      if idx < #layers then
        model:add(nn.LeakyReLU(0.1, true))
      end
    end
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "vgg_7"
    model.w2nn_offset = 7
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    -- shape check:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- VGG style net (12 layers): twelve unpadded 3x3 convolutions with
  -- LeakyReLU(0.1) between them. The output goes through w2nn.InplaceClip01 and
  -- is then flattened per sample.
  function srcnn.vgg_12(backend, ch)
    local layers = {
      {ch, 32}, {32, 32}, {32, 64},
      {64, 64}, {64, 64}, {64, 64}, {64, 64}, {64, 64}, {64, 64},
      {64, 128}, {128, 128}, {128, ch},
    }
    local model = nn.Sequential()
    for idx, planes in ipairs(layers) do
      model:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
      if idx < #layers then
        model:add(nn.LeakyReLU(0.1, true))
      end
    end
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "vgg_12"
    model.w2nn_offset = 12
    model.w2nn_scale_factor = 1
    model.w2nn_resize = false
    model.w2nn_channels = ch
    -- shape check:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- Dilated Convolution (7 layers): 3x3 unpadded convolutions with dilations
  -- 1, 1, 2, 2, 4, 1, 1 and LeakyReLU(0.1) between them.
  -- A 3x3 conv with dilation d trims d px per side, so the offset is
  -- 1+1+2+2+4+1+1 = 12.
  -- NOTE(review): the dilated layers are always nn.SpatialDilatedConvolution,
  -- even for the cudnn backend (the backend-aware SpatialDilatedConvolution
  -- helper is not used here) — confirm this is intentional.
  function srcnn.dilated_7(backend, ch)
    local layers = { -- {n_in, n_out, dilation}
      {ch, 32, 1}, {32, 32, 1}, {32, 64, 2}, {64, 64, 2}, {64, 128, 4}, {128, 128, 1}, {128, ch, 1},
    }
    local model = nn.Sequential()
    for idx, spec in ipairs(layers) do
      local n_in, n_out, dilation = spec[1], spec[2], spec[3]
      if dilation == 1 then
        model:add(SpatialConvolution(backend, n_in, n_out, 3, 3, 1, 1, 0, 0))
      else
        model:add(nn.SpatialDilatedConvolution(n_in, n_out, 3, 3, 1, 1, 0, 0, dilation, dilation))
      end
      if idx < #layers then
        model:add(nn.LeakyReLU(0.1, true))
      end
    end
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "dilated_7"
    model.w2nn_offset = 12
    model.w2nn_scale_factor = 1
    model.w2nn_resize = false
    model.w2nn_channels = ch
    -- shape check:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- Upconvolution: six unpadded 3x3 convolutions (each followed by
  -- LeakyReLU(0.1)), then a bias-free 4x4 stride-2 transposed convolution
  -- that produces the 2x output.
  function srcnn.upconv_7(backend, ch)
    local layers = {{ch, 16}, {16, 32}, {32, 64}, {64, 128}, {128, 128}, {128, 256}}
    local model = nn.Sequential()
    for _, planes in ipairs(layers) do
      model:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "upconv_7"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
  end
  -- Large version of upconv_7 (wider layers, same depth and offset).
  -- This model is able to beat upconv_7 (PSNR: +0.3 ~ +0.8), but it is 2x
  -- slower than upconv_7.
  function srcnn.upconv_7l(backend, ch)
    local layers = {{ch, 32}, {32, 64}, {64, 128}, {128, 192}, {192, 256}, {256, 512}}
    local model = nn.Sequential()
    for _, planes in ipairs(layers) do
      model:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "upconv_7l"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    -- shape check:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- Layerwise linear blending with skip connections.
  -- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l
  function srcnn.skiplb_7(backend, ch)
    -- skip(i, o): padded 3x3 conv + LeakyReLU, depth-concatenated (dim 2) with
    -- its own input. The block outputs i + o planes at the same spatial size.
    local function skip(backend, i, o)
      local branch = nn.Sequential()
      branch:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1))
      branch:add(nn.LeakyReLU(0.1, true))
      local joined = nn.Concat(2)
      joined:add(branch)         -- depth concat
      joined:add(nn.Identity())  -- skip
      return joined
    end
    local model = nn.Sequential()
    local n_planes = ch
    for _, width in ipairs({16, 32, 64, 128, 128, 256}) do
      model:add(skip(backend, n_planes, width))
      n_planes = n_planes + width
    end
    -- The last layer sees every layer's output, the input included
    -- (ch + 16 + 32 + 64 + 128 + 128 + 256 planes): a linear blend.
    model:add(SpatialFullConvolution(backend, n_planes, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "skiplb_7"
    model.w2nn_offset = 14
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    -- shape check:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- Dilated convolution + deconvolution.
  -- Note: this model is not better than upconv_7, maybe because of under-fitting.
  function srcnn.dilated_upconv_7(backend, ch)
    local layers = { -- {n_in, n_out, dilation}
      {ch, 16, 1}, {16, 32, 1}, {32, 64, 2}, {64, 128, 2}, {128, 128, 2}, {128, 256, 1},
    }
    local model = nn.Sequential()
    for _, spec in ipairs(layers) do
      local n_in, n_out, dilation = spec[1], spec[2], spec[3]
      if dilation == 1 then
        model:add(SpatialConvolution(backend, n_in, n_out, 3, 3, 1, 1, 0, 0))
      else
        model:add(nn.SpatialDilatedConvolution(n_in, n_out, 3, 3, 1, 1, 0, 0, dilation, dilation))
      end
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "dilated_upconv_7"
    model.w2nn_offset = 20
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    -- shape check:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- ref: https://arxiv.org/abs/1609.04802
  -- SRResNet-style 2x upscaler. Note: no batch-norm, no zero-padding.
  function srcnn.srresnet_2x(backend, ch)
    -- Residual block (64 -> 64): two unpadded 3x3 convs with ReLU. The shortcut
    -- is cropped 2 px per side to match the conv branch, and the two are summed.
    local function resblock(backend)
      local body = nn.Sequential()
      body:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      body:add(ReLU(backend))
      body:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      body:add(ReLU(backend))
      local branches = nn.ConcatTable()
      branches:add(body)
      branches:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
      return nn.Sequential():add(branches):add(nn.CAddTable())
    end
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    for _ = 1, 6 do
      model:add(resblock(backend))
    end
    model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2))
    model:add(ReLU(backend))
    model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
    model:add(w2nn.InplaceClip01())
    -- Unlike most models here, there is no trailing flatten:
    --model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "srresnet_2x"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    -- shape check:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- Large version of srresnet_2x. It's the current best model, but slow.
  function srcnn.resnet_14l(backend, ch)
    -- Residual block i -> o: two unpadded 3x3 convs with LeakyReLU(0.1).
    -- The shortcut is cropped 2 px per side; when i ~= o it is first projected
    -- by a 1x1 conv.
    local function resblock(backend, i, o)
      local body = nn.Sequential()
      body:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
      body:add(nn.LeakyReLU(0.1, true))
      body:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
      body:add(nn.LeakyReLU(0.1, true))
      local shortcut
      if i ~= o then
        shortcut = nn.Sequential()
        shortcut:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
        shortcut:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
      else
        shortcut = nn.SpatialZeroPadding(-2, -2, -2, -2) -- identity + de-padding
      end
      local branches = nn.ConcatTable():add(body):add(shortcut)
      return nn.Sequential():add(branches):add(nn.CAddTable())
    end
    local model = nn.Sequential()
    model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    for _, planes in ipairs({{32, 64}, {64, 64}, {64, 128}, {128, 128}, {128, 256}, {256, 256}}) do
      model:add(resblock(backend, planes[1], planes[2]))
    end
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "resnet_14l"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    -- shape check:
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
    return model
  end
  -- For segmentation: encoder/decoder FCN built for 120x120 inputs
  -- (w2nn_input_size).
  -- Encoder: strided conv, conv/max-pool stages, 1x1 bottleneck, dropout.
  -- Decoder: 2x2 stride-2 transposed convs, then a 4x4 stride-2 transposed
  -- conv head.
  function srcnn.fcn_v1(backend, ch)
    local model = nn.Sequential()
    -- unpadded k x k conv with the given stride, followed by LeakyReLU(0.1)
    local function conv(n_in, n_out, k, stride)
      model:add(SpatialConvolution(backend, n_in, n_out, k, k, stride, stride, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    -- 2x2 stride-2 transposed conv followed by LeakyReLU(0.1)
    local function up2(n_in, n_out)
      model:add(SpatialFullConvolution(backend, n_in, n_out, 2, 2, 2, 2, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    local function pool()
      model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
    end
    -- shape check (i = 120):
    --model:cuda()
    --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
    conv(ch, 32, 5, 2)
    conv(32, 32, 3, 1)
    pool()
    conv(32, 64, 3, 1)
    conv(64, 64, 3, 1)
    pool()
    conv(64, 128, 3, 1)
    conv(128, 128, 3, 1)
    pool()
    conv(128, 256, 1, 1)
    model:add(nn.Dropout(0.5, false, true))
    up2(256, 128)
    up2(128, 128)
    conv(128, 64, 3, 1)
    up2(64, 64)
    conv(64, 32, 3, 1)
    model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "fcn_v1"
    model.w2nn_offset = 36
    model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
    model.w2nn_input_size = 120
    --model.w2nn_gcn = true
    return model
  end
  -- Densely connected 2x upscaler.
  -- 1. 13 unpadded 3x3 conv layers (m planes each) are chained while every
  --    intermediate map and the input are kept.
  -- 2. All of them are cropped to the size of the deepest map.
  -- 3. The intermediate maps are re-weighted by a 1x1 conv.
  -- 4. Everything is concatenated along dim 2 and fed to a bias-free 4x4
  --    stride-2 transposed convolution.
  function srcnn.cupconv_14(backend, ch)
    -- skip(n_input, n_output): takes a table whose first element is the most
    -- recent feature map and returns {conv(first element), <input table>},
    -- i.e. the new map pushed onto a nested history.
    local function skip(backend, n_input, n_output)
      local conv = nn.Sequential()
      conv:add(nn.SelectTable(1))
      conv:add(SpatialConvolution(backend, n_input, n_output, 3, 3, 1, 1, 0, 0))
      conv:add(nn.LeakyReLU(0.1, true))
      local con = nn.ConcatTable()
      con:add(conv)
      con:add(nn.Identity())
      return con
    end
    -- concat(n, n_middle): takes the flattened history (deepest map first,
    -- network input last) and processes element i as follows:
    --   * crop it by i - 1 px per side so it matches the deepest map;
    --   * for intermediate maps (1 < i < n), apply GradWeight + 1x1 conv;
    --   * for every i > 1, apply LeakyReLU.
    -- The results are joined along dim 2.
    local function concat(backend, n, n_middle)
      local con = nn.ConcatTable()
      for i = 1, n do
        local seq = nn.Sequential()
        seq:add(nn.SelectTable(i))
        if i > 1 then
          local pad = i - 1
          seq:add(nn.SpatialZeroPadding(-pad, -pad, -pad, -pad))
          if i < n then
            seq:add(w2nn.GradWeight(0.025))
            seq:add(SpatialConvolution(backend, n_middle, n_middle, 1, 1, 1, 1, 0, 0))
          end
          seq:add(nn.LeakyReLU(0.1, true))
        end
        con:add(seq)
      end
      return nn.Sequential():add(con):add(nn.JoinTable(2))
    end
    local model = nn.Sequential()
    local m = 64
    local n = 14
    model:add(nn.ConcatTable():add(nn.Identity())) -- x -> {x}
    model:add(skip(backend, ch, m))
    for _ = 2, n - 1 do
      model:add(skip(backend, m, m))
    end
    model:add(nn.FlattenTable())
    model:add(concat(backend, n, m))
    -- (n - 1) feature maps of m planes plus the ch-plane input
    -- (was hard-coded as 3, which broke 1-channel models)
    model:add(SpatialFullConvolution(backend, m * (n - 1) + ch, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(nn.View(-1):setNumInputDims(3))
    model.w2nn_arch_name = "cupconv_14"
    model.w2nn_offset = 28
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
    model.w2nn_resize = true
    return model
  end
  -- upconv_7-style 2x upscaler followed by a refinement stage trained with an
  -- auxiliary loss.
  function srcnn.upconv_refine(backend, ch)
    -- Refinement stage. It builds the table
    --   { (res(x) * 0.5) + (crop4(x) * 0.5), crop4(x) }   (both flattened)
    -- where res is four unpadded 3x3 convs (4 px trimmed per side). The table
    -- is passed to w2nn.AuxiliaryLossTable(1).
    local function block(backend, ch)
      local res = nn.Sequential()
      res:add(w2nn.GradWeight(0.1))
      for _, planes in ipairs({{ch, 32}, {32, 64}, {64, 128}}) do
        res:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
        res:add(nn.LeakyReLU(0.1, true))
      end
      res:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0):noBias())
      res:add(w2nn.InplaceClip01())
      res:add(nn.MulConstant(0.5))
      local shortcut = nn.Sequential()
      shortcut:add(nn.SpatialZeroPadding(-4, -4, -4, -4))
      shortcut:add(nn.MulConstant(0.5))
      local con = nn.ConcatTable():add(res):add(shortcut)

      -- main output
      local refine = nn.Sequential()
      refine:add(nn.CAddTable()) -- averaging
      refine:add(nn.View(-1):setNumInputDims(3))
      -- aux output
      local base = nn.Sequential()
      base:add(nn.SelectTable(2))
      base:add(nn.MulConstant(2)) -- revert mul 0.5
      base:add(nn.View(-1):setNumInputDims(3))
      local aux_con = nn.ConcatTable():add(refine):add(base)

      local seq = nn.Sequential()
      seq:add(con)
      seq:add(aux_con)
      seq:add(w2nn.AuxiliaryLossTable(1))
      return seq
    end
    local model = nn.Sequential()
    for _, planes in ipairs({{ch, 32}, {32, 32}, {32, 64}, {64, 128}, {128, 128}, {128, 256}}) do
      model:add(SpatialConvolution(backend, planes[1], planes[2], 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
    end
    model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    model:add(w2nn.InplaceClip01())
    model:add(block(backend, ch))
    model.w2nn_arch_name = "upconv_refine"
    model.w2nn_offset = 18
    model.w2nn_scale_factor = 2
    model.w2nn_resize = true
    model.w2nn_channels = ch
    return model
  end
  652. -- cascaded residual channel attention unet
  653. function srcnn.upcunet(backend, ch)
  654. function unet_branch(insert, backend, n_input, n_output, depad)
  655. local block = nn.Sequential()
  656. local con = nn.ConcatTable(2)
  657. local model = nn.Sequential()
  658. block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0))-- downsampling
  659. block:add(insert)
  660. block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
  661. con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
  662. con:add(block)
  663. model:add(con)
  664. model:add(nn.CAddTable())
  665. return model
  666. end
  667. function unet_conv(n_input, n_middle, n_output, se)
  668. local model = nn.Sequential()
  669. model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
  670. model:add(nn.LeakyReLU(0.1, true))
  671. model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
  672. model:add(nn.LeakyReLU(0.1, true))
  673. if se then
  674. -- Squeeze and Excitation Networks
  675. local con = nn.ConcatTable(2)
  676. local attention = nn.Sequential()
  677. attention:add(nn.SpatialAdaptiveAveragePooling(1, 1)) -- global average pooling
  678. attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / 4), 1, 1, 1, 1, 0, 0))
  679. attention:add(nn.ReLU(true))
  680. attention:add(SpatialConvolution(backend, math.floor(n_output / 4), n_output, 1, 1, 1, 1, 0, 0))
  681. attention:add(nn.Sigmoid(true))
  682. con:add(nn.Identity())
  683. con:add(attention)
  684. model:add(con)
  685. model:add(w2nn.ScaleTable())
  686. end
  687. return model
  688. end
  -- Residual U-Net with two nested residual branches.
  --
  -- Layout: unet_conv(ch -> 32 -> 64) -> residual branch at half resolution
  -- (which itself nests a residual branch at quarter resolution) ->
  -- 3x3 conv -> LeakyReLU(0.1) -> output layer.
  --
  -- backend : forwarded to the convolution helpers
  -- ch      : image channels in and out
  -- deconv  : if truthy, the output layer is a 4x4 stride-2 full
  --           convolution (2x upscaling); otherwise an unpadded 3x3 conv
  -- Returns an nn.Sequential.
  function unet(backend, ch, deconv)
    -- innermost level (quarter resolution)
    local inner = unet_conv(128, 256, 128, true)
    -- middle level (half resolution), wrapping the innermost one
    local middle = nn.Sequential()
      :add(unet_conv(64, 64, 128, true))
      :add(unet_branch(inner, backend, 128, 128, 4))
      :add(unet_conv(128, 64, 64, true))
    local net = nn.Sequential()
      :add(unet_conv(ch, 32, 64, false))
      :add(unet_branch(middle, backend, 64, 64, 16))
      :add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      :add(nn.LeakyReLU(0.1))
    if not deconv then
      return net:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
    end
    return net:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
  end
  -- 2 cascade:
  --   stage 1: unet(deconv = true) upscales its input 2x;
  --   stage 2: a second unet (no deconv) refines stage 1's output, and the
  --            refinement is added to stage 1's output cropped by 20 px per
  --            side (the second unet trims 40 px in total).
  -- Size bookkeeping per spatial dimension, for an input of S px: stage 1
  -- yields 2S-80 px and stage 2 trims a further 40 px, so the final output
  -- is 2S-120 px, i.e. 60 px per side at output scale (w2nn_offset below).
  local stage1 = unet(backend, ch, true)
  local refine = nn.ConcatTable()
  refine:add(unet(backend, ch, false))
  refine:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
  -- Two outputs, each passed through w2nn.InplaceClip01 (by its name,
  -- presumably an in-place clamp to [0, 1]).
  local heads = nn.ConcatTable()
    :add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
    :add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
  local net = nn.Sequential():add(stage1):add(refine):add(heads)
  net:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
  -- model metadata (srcnn.create checks w2nn_offset against w2nn_scale_factor)
  net.w2nn_arch_name = "upcunet"
  net.w2nn_offset = 60
  net.w2nn_scale_factor = 2
  net.w2nn_channels = ch
  net.w2nn_resize = true
  -- 72, 128, 256 are valid
  --model.w2nn_input_size = 128
  return net
end
  -- 2x upscaler built from two parallel paths whose outputs are summed:
  --   * a small "base" path, and
  --   * a "texture" path fed by w2nn.EdgeFilter.
  -- The model returns two flattened outputs (summed+clipped, and base-only),
  -- followed by w2nn.AuxiliaryLossTable(1).
  -- Note: base_upscaler, block and texture_upscaler are (re)defined as
  -- globals on every call.
  function srcnn.prog_net(backend, ch)
    -- Base path: crop 11 px per side, two unpadded 3x3 convs, then a 4x4
    -- stride-2 full convolution without bias (out = 2 * in - 4), clipped.
    -- For an input of S px this yields 2 * (S - 26) - 4 = 2S - 56 px, i.e.
    -- 28 px per side at output scale (w2nn_offset below).
    function base_upscaler(backend, ch)
      return nn.Sequential()
        :add(nn.SpatialZeroPadding(-11, -11, -11, -11))
        :add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
        :add(nn.LeakyReLU(0.1, true))
        :add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
        :add(nn.LeakyReLU(0.1, true))
        :add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3):noBias())
        :add(w2nn.InplaceClip01())
    end
    -- Two-path block summed by CAddTable, then LeakyReLU:
    --   plain   : 3x3 conv, then crop 5 px per side      (12 px trimmed)
    --   dilated : 3x3 conv dilation 2 -> LeakyReLU ->
    --             3x3 conv dilation 4                      (12 px trimmed)
    -- (assumes the last two SpatialDilatedConvolution arguments are the
    -- dilation, as in nn.SpatialDilatedConvolution)
    function block(backend, input, output)
      local plain = nn.Sequential()
        :add(SpatialConvolution(backend, input, output, 3, 3, 1, 1, 0, 0))
        :add(nn.SpatialZeroPadding(-5, -5, -5, -5))
      local dilated = nn.Sequential()
        :add(SpatialDilatedConvolution(backend, input, output, 3, 3, 1, 1, 0, 0, 2, 2))
        :add(nn.LeakyReLU(0.1, true))
        :add(SpatialDilatedConvolution(backend, output, output, 3, 3, 1, 1, 0, 0, 4, 4))
      return nn.Sequential()
        :add(nn.ConcatTable():add(plain):add(dilated))
        :add(nn.CAddTable())
        :add(nn.LeakyReLU(0.1, true))
    end
    -- Texture path: EdgeFilter -> 1x1 conv (ch * 8 -> 32) -> two blocks ->
    -- 4x4 stride-2 full convolution without bias.
    -- NOTE(review): the 1x1 conv expects EdgeFilter to emit ch * 8 channels,
    -- and for the sum with the base path EdgeFilter presumably trims 1 px per
    -- side; neither is visible here, so confirm in w2nn.
    function texture_upscaler(backend, ch)
      return nn.Sequential()
        :add(w2nn.EdgeFilter(ch))
        :add(SpatialConvolution(backend, ch * 8, 32, 1, 1, 1, 1, 0, 0))
        :add(nn.LeakyReLU(0.1, true))
        :add(block(backend, 32, 128))
        :add(block(backend, 128, 256))
        :add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
    end
    -- output 1: base path, output 2: texture path
    local paths = nn.ConcatTable()
    paths:add(base_upscaler(backend, ch))
    paths:add(texture_upscaler(backend, ch))
    -- head 1: sum of both paths, clipped, flattened per sample
    -- head 2: base path alone, flattened per sample
    local heads = nn.ConcatTable()
      :add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01()):add(nn.View(-1):setNumInputDims(3)))
      :add(nn.Sequential():add(nn.SelectTable(1)):add(nn.View(-1):setNumInputDims(3)))
    local net = nn.Sequential():add(paths):add(heads):add(w2nn.AuxiliaryLossTable(1))
    net.w2nn_arch_name = "prog_net"
    net.w2nn_offset = 28
    net.w2nn_scale_factor = 2
    net.w2nn_channels = ch
    net.w2nn_resize = true
    return net
  end
  -- Build a model by name.
  --
  -- model_name : key of a constructor in `srcnn` (default "vgg_7"); it is
  --              called as srcnn[model_name](backend, ch)
  -- backend    : backend name forwarded to the constructor (default "cunn")
  -- color      : "rgb" (3 channels, default) or "y" (1 channel)
  -- Returns the constructed model.
  -- Raises an error for an unsupported color or model name. Also raises one
  -- when the model's w2nn_offset is not a multiple of its w2nn_scale_factor.
  function srcnn.create(model_name, backend, color)
    model_name = model_name or "vgg_7"
    backend = backend or "cunn"
    color = color or "rgb"
    -- number of image channels for each supported color mode
    local channels_by_color = { rgb = 3, y = 1 }
    local ch = channels_by_color[color]
    if not ch then
      -- tostring: plain concatenation of e.g. a boolean or table would raise
      -- "attempt to concatenate" instead of this message
      error("unsupported color: " .. tostring(color))
    end
    local constructor = srcnn[model_name]
    if type(constructor) ~= "function" then
      error("unsupported model_name: " .. tostring(model_name))
    end
    local model = constructor(backend, ch)
    local offset, scale = model.w2nn_offset, model.w2nn_scale_factor
    assert(offset % scale == 0,
           "model " .. tostring(model_name) .. ": w2nn_offset (" .. tostring(offset) ..
           ") is not a multiple of w2nn_scale_factor (" .. tostring(scale) .. ")")
    return model
  end
  805. --[[
  806. local model = srcnn.fcn_v1("cunn", 3):cuda()
  807. print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
  808. print(model)
  809. local model = srcnn.unet_refine("cunn", 3):cuda()
  810. print(model)
  811. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
  812. local model = srcnn.cupconv_14("cunn", 3):cuda()
  813. print(model)
  814. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
  815. os.exit()
  816. local model = srcnn.cupconv_14("cunn", 3):cuda()
  817. print(model)
  818. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
  819. os.exit()
  820. local model = srcnn.upconv_refine("cunn", 3):cuda()
  821. print(model)
  822. model:training()
  823. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()))
  824. os.exit()
  825. local model = srcnn.nw2("cunn", 3):cuda()
  826. print(model)
  827. model:training()
  828. print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()))
  829. os.exit()
  830. local model = srcnn.prog_net("cunn", 3):cuda()
  831. print(model)
  832. model:training()
  833. print(model:forward(torch.Tensor(1, 3, 128, 128):zero():cuda()))
  834. os.exit()
  835. local model = srcnn.double_unet("cunn", 3):cuda()
  836. print(model)
  837. model:training()
  838. print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()))
  839. os.exit()
  840. local model = srcnn.cunet_v3("cunn", 3):cuda()
  841. print(model)
  842. model:training()
  843. print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()):size())
  844. os.exit()
  845. local model = srcnn.cunet_v6("cunn", 3):cuda()
  846. print(model)
  847. model:training()
  848. print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()))
  849. os.exit()
  850. --]]
  -- Module export: the table on which srcnn.create and the model
  -- constructors (e.g. srcnn.prog_net) are defined.
  return srcnn