-- srcnn.lua

require 'w2nn'
-- ref: http://arxiv.org/abs/1502.01852
-- ref: http://arxiv.org/abs/1501.00092
local srcnn = {}
local function msra_filler(mod)
   local fin = mod.kW * mod.kH * mod.nInputPlane
   local fout = mod.kW * mod.kH * mod.nOutputPlane
   local stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
   mod.weight:normal(0, stdv)
   mod.bias:zero()
end
local function identity_filler(mod)
   assert(mod.nInputPlane <= mod.nOutputPlane)
   mod.weight:normal(0, 0.01)
   mod.bias:zero()
   local num_groups = mod.nInputPlane -- fixed
   local filler_value = num_groups / mod.nOutputPlane
   local in_group_size = math.floor(mod.nInputPlane / num_groups)
   local out_group_size = math.floor(mod.nOutputPlane / num_groups)
   local x = math.floor(mod.kW / 2)
   local y = math.floor(mod.kH / 2)
   for i = 0, num_groups - 1 do
      for j = i * out_group_size, (i + 1) * out_group_size - 1 do
         for k = i * in_group_size, (i + 1) * in_group_size - 1 do
            mod.weight[j+1][k+1][y+1][x+1] = filler_value
         end
      end
   end
end
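-- Worked example for identity_filler (a sketch; the plane counts are assumed):
-- for a dilated 3x3 layer with nInputPlane = nOutputPlane = 64, num_groups is 64,
-- filler_value is 1, and only the center tap weight[i][i][2][2] of each matching
-- input/output channel pair is set to 1, so the layer starts out as approximately
-- an identity mapping plus small Gaussian noise.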
function nn.SpatialConvolutionMM:reset(stdv)
   msra_filler(self)
end
function nn.SpatialFullConvolution:reset(stdv)
   msra_filler(self)
end
function nn.SpatialDilatedConvolution:reset(stdv)
   identity_filler(self)
end
if cudnn and cudnn.SpatialConvolution then
   function cudnn.SpatialConvolution:reset(stdv)
      msra_filler(self)
   end
   function cudnn.SpatialFullConvolution:reset(stdv)
      msra_filler(self)
   end
   if cudnn.SpatialDilatedConvolution then
      function cudnn.SpatialDilatedConvolution:reset(stdv)
         identity_filler(self)
      end
   end
end
function nn.SpatialConvolutionMM:clearState()
   if self.gradWeight then
      self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
   end
   if self.gradBias then
      self.gradBias:resize(self.nOutputPlane):zero()
   end
   return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
function srcnn.channels(model)
   if model.w2nn_channels ~= nil then
      return model.w2nn_channels
   else
      return model:get(model:size() - 1).weight:size(1)
   end
end
function srcnn.backend(model)
   local conv = model:findModules("cudnn.SpatialConvolution")
   local fullconv = model:findModules("cudnn.SpatialFullConvolution")
   if #conv > 0 or #fullconv > 0 then
      return "cudnn"
   else
      return "cunn"
   end
end
function srcnn.color(model)
   local ch = srcnn.channels(model)
   if ch == 3 then
      return "rgb"
   else
      return "y"
   end
end
function srcnn.name(model)
   if model.w2nn_arch_name ~= nil then
      return model.w2nn_arch_name
   else
      local conv = model:findModules("nn.SpatialConvolutionMM")
      if #conv == 0 then
         conv = model:findModules("cudnn.SpatialConvolution")
      end
      if #conv == 7 then
         return "vgg_7"
      elseif #conv == 12 then
         return "vgg_12"
      else
         error("unsupported model")
      end
   end
end
function srcnn.offset_size(model)
   if model.w2nn_offset ~= nil then
      return model.w2nn_offset
   else
      local name = srcnn.name(model)
      if name:match("vgg_") then
         local conv = model:findModules("nn.SpatialConvolutionMM")
         if #conv == 0 then
            conv = model:findModules("cudnn.SpatialConvolution")
         end
         local offset = 0
         for i = 1, #conv do
            offset = offset + (conv[i].kW - 1) / 2
         end
         return math.floor(offset)
      else
         error("unsupported model")
      end
   end
end
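-- Worked example (values taken from the vgg_7 definition below): seven 3x3
-- convolutions with no padding each add (3 - 1) / 2 = 1 to the offset, so
--   srcnn.offset_size(srcnn.vgg_7("cunn", 3)) --> 7
-- i.e. the output is cropped by 7 pixels on each border relative to the input.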
function srcnn.scale_factor(model)
   if model.w2nn_scale_factor ~= nil then
      return model.w2nn_scale_factor
   else
      local name = srcnn.name(model)
      if name == "upconv_7" then
         return 2
      elseif name == "upconv_8_4x" then
         return 4
      else
         return 1
      end
   end
end
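-- Minimal usage sketch for the introspection helpers above (any model built by
-- this module carries the w2nn_* metadata fields these helpers read):
--   local model = srcnn.upconv_7("cunn", 3)
--   print(srcnn.name(model))         -- "upconv_7"
--   print(srcnn.channels(model))     -- 3
--   print(srcnn.offset_size(model))  -- 14
--   print(srcnn.scale_factor(model)) -- 2
--   print(srcnn.backend(model))      -- "cunn"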
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialConvolution = SpatialConvolution
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
   if backend == "cunn" then
      return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, adjW, adjH)
   elseif backend == "cudnn" then
      return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialFullConvolution = SpatialFullConvolution
local function ReLU(backend)
   if backend == "cunn" then
      return nn.ReLU(true)
   elseif backend == "cudnn" then
      return cudnn.ReLU(true)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.ReLU = ReLU
local function Sigmoid(backend)
   if backend == "cunn" then
      return nn.Sigmoid(true)
   elseif backend == "cudnn" then
      return cudnn.Sigmoid(true)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.Sigmoid = Sigmoid
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialMaxPooling = SpatialMaxPooling
local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialAveragePooling = SpatialAveragePooling
local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   if backend == "cunn" then
      return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   elseif backend == "cudnn" then
      if cudnn.SpatialDilatedConvolution then
         -- cudnn v6
         return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
      else
         return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
      end
   else
      error("unsupported backend:" .. backend)
   end
end
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
local function GlobalAveragePooling(n_output)
   local gap = nn.Sequential()
   gap:add(nn.Mean(-1, -1)):add(nn.Mean(-1, -1))
   gap:add(nn.View(-1, n_output, 1, 1))
   return gap
end
srcnn.GlobalAveragePooling = GlobalAveragePooling
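-- Shape sketch for GlobalAveragePooling (sizes assumed): a batch tensor of size
-- (N, n_output, H, W) is averaged over W and then over H by the two nn.Mean
-- layers, and reshaped back to (N, n_output, 1, 1) so it can scale feature maps
-- in the SEBlock defined further below.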
-- VGG style net (7 layers)
function srcnn.vgg_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "vgg_7"
   model.w2nn_offset = 7
   model.w2nn_scale_factor = 1
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- VGG style net (12 layers)
function srcnn.vgg_12(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "vgg_12"
   model.w2nn_offset = 12
   model.w2nn_scale_factor = 1
   model.w2nn_resize = false
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- Dilated Convolution (7 layers)
function srcnn.dilated_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "dilated_7"
   model.w2nn_offset = 12
   model.w2nn_scale_factor = 1
   model.w2nn_resize = false
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- Upconvolution
function srcnn.upconv_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "upconv_7"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   return model
end
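-- Output-size sketch for upconv_7 (arithmetic only; the input size s is assumed):
-- six unpadded 3x3 convolutions shrink an s x s input to (s - 12) x (s - 12), and
-- the final 4x4 stride-2 full convolution with padding 3 yields
-- 2 * (s - 12) - 4 = 2s - 28 = 2s - 2 * w2nn_offset,
-- which is why w2nn_offset is 14 at scale factor 2.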
-- large version of upconv_7
-- This model is able to beat upconv_7 (PSNR: +0.3 ~ +0.8) but is 2x slower than upconv_7.
function srcnn.upconv_7l(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 192, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 192, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 256, 512, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 512, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "upconv_7l"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- layerwise linear blending with skip connections
-- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l
function srcnn.skiplb_7(backend, ch)
   local function skip(backend, i, o)
      local con = nn.Concat(2)
      local conv = nn.Sequential()
      conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1))
      conv:add(nn.LeakyReLU(0.1, true))
      -- depth concat
      con:add(conv)
      con:add(nn.Identity()) -- skip
      return con
   end
   local model = nn.Sequential()
   model:add(skip(backend, ch, 16))
   model:add(skip(backend, 16+ch, 32))
   model:add(skip(backend, 32+16+ch, 64))
   model:add(skip(backend, 64+32+16+ch, 128))
   model:add(skip(backend, 128+64+32+16+ch, 128))
   model:add(skip(backend, 128+128+64+32+16+ch, 256))
   -- input of the last layer = all layerwise outputs (including the input layer), depth-concatenated
   model:add(SpatialFullConvolution(backend, 256+128+128+64+32+16+ch, ch, 4, 4, 2, 2, 3, 3):noBias()) -- linear blend
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "skiplb_7"
   model.w2nn_offset = 14
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- dilated convolution + deconvolution
-- Note: This model is not better than upconv_7. Maybe because of under-fitting.
function srcnn.dilated_upconv_7(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.SpatialDilatedConvolution(128, 128, 3, 3, 1, 1, 0, 0, 2, 2))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "dilated_upconv_7"
   model.w2nn_offset = 20
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- ref: https://arxiv.org/abs/1609.04802
-- note: no batch-norm, no zero-padding
function srcnn.srresnet_2x(backend, ch)
   local function resblock(backend)
      local seq = nn.Sequential()
      local con = nn.ConcatTable()
      local conv = nn.Sequential()
      conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      conv:add(ReLU(backend))
      conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      conv:add(ReLU(backend))
      con:add(conv)
      con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
      seq:add(con)
      seq:add(nn.CAddTable())
      return seq
   end
   local model = nn.Sequential()
   --model:add(skip(backend, ch, 64 - ch))
   model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(resblock(backend))
   model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2))
   model:add(ReLU(backend))
   model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.InplaceClip01())
   --model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "srresnet_2x"
   model.w2nn_offset = 28
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- large version of srresnet_2x. It's the current best model but slow.
function srcnn.resnet_14l(backend, ch)
   local function resblock(backend, i, o)
      local seq = nn.Sequential()
      local con = nn.ConcatTable()
      local conv = nn.Sequential()
      conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 0, 0))
      conv:add(nn.LeakyReLU(0.1, true))
      conv:add(SpatialConvolution(backend, o, o, 3, 3, 1, 1, 0, 0))
      conv:add(nn.LeakyReLU(0.1, true))
      con:add(conv)
      if i == o then
         con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
      else
         local seq = nn.Sequential()
         seq:add(SpatialConvolution(backend, i, o, 1, 1, 1, 1, 0, 0))
         seq:add(nn.SpatialZeroPadding(-2, -2, -2, -2))
         con:add(seq)
      end
      seq:add(con)
      seq:add(nn.CAddTable())
      return seq
   end
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(resblock(backend, 32, 64))
   model:add(resblock(backend, 64, 64))
   model:add(resblock(backend, 64, 128))
   model:add(resblock(backend, 128, 128))
   model:add(resblock(backend, 128, 256))
   model:add(resblock(backend, 256, 256))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "resnet_14l"
   model.w2nn_offset = 28
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
-- for segmentation
function srcnn.fcn_v1(backend, ch)
   -- input_size = 120
   local model = nn.Sequential()
   --i = 120
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, i, i):uniform():cuda()):size())
   model:add(SpatialConvolution(backend, ch, 32, 5, 5, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialMaxPooling(backend, 2, 2, 2, 2))
   model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(nn.Dropout(0.5, false, true))
   model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 128, 128, 2, 2, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 64, 64, 2, 2, 2, 2, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 32, ch, 4, 4, 2, 2, 3, 3))
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "fcn_v1"
   model.w2nn_offset = 36
   model.w2nn_scale_factor = 1
   model.w2nn_channels = ch
   model.w2nn_input_size = 120
   --model.w2nn_gcn = true
   return model
end
function srcnn.cupconv_14(backend, ch)
   local function skip(backend, n_input, n_output, pad)
      local con = nn.ConcatTable()
      local conv = nn.Sequential()
      local depad = nn.Sequential()
      conv:add(nn.SelectTable(1))
      conv:add(SpatialConvolution(backend, n_input, n_output, 3, 3, 1, 1, 0, 0))
      conv:add(nn.LeakyReLU(0.1, true))
      con:add(conv)
      con:add(nn.Identity())
      return con
   end
   local function concat(backend, n, ch, n_middle)
      local con = nn.ConcatTable()
      for i = 1, n do
         local pad = i - 1
         if i == 1 then
            con:add(nn.Sequential():add(nn.SelectTable(i)))
         else
            local seq = nn.Sequential()
            seq:add(nn.SelectTable(i))
            if pad > 0 then
               seq:add(nn.SpatialZeroPadding(-pad, -pad, -pad, -pad))
            end
            if i == n then
               --seq:add(SpatialConvolution(backend, ch, n_middle, 1, 1, 1, 1, 0, 0))
            else
               seq:add(w2nn.GradWeight(0.025))
               seq:add(SpatialConvolution(backend, n_middle, n_middle, 1, 1, 1, 1, 0, 0))
            end
            seq:add(nn.LeakyReLU(0.1, true))
            con:add(seq)
         end
      end
      return nn.Sequential():add(con):add(nn.JoinTable(2))
   end
   local model = nn.Sequential()
   local m = 64
   local n = 14
   model:add(nn.ConcatTable():add(nn.Identity()))
   for i = 1, n - 1 do
      if i == 1 then
         model:add(skip(backend, ch, m))
      else
         model:add(skip(backend, m, m))
      end
   end
   model:add(nn.FlattenTable())
   model:add(concat(backend, n, ch, m))
   model:add(SpatialFullConvolution(backend, m * (n - 1) + 3, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(nn.View(-1):setNumInputDims(3))
   model.w2nn_arch_name = "cupconv_14"
   model.w2nn_offset = 28
   model.w2nn_scale_factor = 2
   model.w2nn_channels = ch
   model.w2nn_resize = true
   return model
end
function srcnn.upconv_refine(backend, ch)
   local function block(backend, ch)
      local seq = nn.Sequential()
      local con = nn.ConcatTable()
      local res = nn.Sequential()
      local base = nn.Sequential()
      local refine = nn.Sequential()
      local aux_con = nn.ConcatTable()
      res:add(w2nn.GradWeight(0.1))
      res:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
      res:add(nn.LeakyReLU(0.1, true))
      res:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
      res:add(nn.LeakyReLU(0.1, true))
      res:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
      res:add(nn.LeakyReLU(0.1, true))
      res:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0):noBias())
      res:add(w2nn.InplaceClip01())
      res:add(nn.MulConstant(0.5))
      con:add(res)
      con:add(nn.Sequential():add(nn.SpatialZeroPadding(-4, -4, -4, -4)):add(nn.MulConstant(0.5)))
      -- main output
      refine:add(nn.CAddTable()) -- averaging
      refine:add(nn.View(-1):setNumInputDims(3))
      -- aux output
      base:add(nn.SelectTable(2))
      base:add(nn.MulConstant(2)) -- revert mul 0.5
      base:add(nn.View(-1):setNumInputDims(3))
      aux_con:add(refine)
      aux_con:add(base)
      seq:add(con)
      seq:add(aux_con)
      seq:add(w2nn.AuxiliaryLossTable(1))
      return seq
   end
   local model = nn.Sequential()
   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
   model:add(nn.LeakyReLU(0.1, true))
   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
   model:add(w2nn.InplaceClip01())
   model:add(block(backend, ch))
   model.w2nn_arch_name = "upconv_refine"
   model.w2nn_offset = 18
   model.w2nn_scale_factor = 2
   model.w2nn_resize = true
   model.w2nn_channels = ch
   return model
end
-- I devised this arch because of the block-size problem with global average pooling,
-- but SEBlock may well be able to handle multi-scale input without any problem.
local function SpatialSEBlock(backend, ave_size, n_output, r)
   local con = nn.ConcatTable(2)
   local attention = nn.Sequential()
   local n_mid = math.floor(n_output / r)
   attention:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size))
   attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
   attention:add(nn.ReLU(true))
   attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
   attention:add(nn.Sigmoid(true))
   attention:add(nn.SpatialUpSamplingNearest(ave_size, ave_size))
   con:add(nn.Identity())
   con:add(attention)
   return con
end
-- Squeeze and Excitation Block
local function SEBlock(backend, n_output, r)
   local con = nn.ConcatTable(2)
   local attention = nn.Sequential()
   local n_mid = math.floor(n_output / r)
   attention:add(GlobalAveragePooling(n_output))
   attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
   attention:add(nn.ReLU(true))
   attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
   attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
   con:add(nn.Identity())
   con:add(attention)
   return con
end
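-- Usage sketch for SEBlock (channel attention; the pattern follows unet_conv below,
-- the channel count 64 is assumed): the ConcatTable emits {features, per-channel
-- weights in (0, 1)}, and a following multiplication module rescales each feature
-- map by its weight.
--   local seq = nn.Sequential()
--   seq:add(SEBlock(backend, 64, 4)) -- squeeze to 64 / 4 = 16 channels, then excite
--   seq:add(w2nn.ScaleTable())       -- multiply features by the attention weights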
-- cascaded residual channel attention unet
function srcnn.upcunet(backend, ch)
   local function unet_branch(insert, backend, n_input, n_output, depad)
      local block = nn.Sequential()
      local con = nn.ConcatTable(2)
      local model = nn.Sequential()
      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
      block:add(insert)
      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
      con:add(block)
      model:add(con)
      model:add(nn.CAddTable())
      return model
   end
   local function unet_conv(n_input, n_middle, n_output, se)
      local model = nn.Sequential()
      model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      if se then
         model:add(SEBlock(backend, n_output, 4))
         model:add(w2nn.ScaleTable())
      end
      return model
   end
   -- Residual U-Net
   local function unet(backend, ch, deconv)
      local block1 = unet_conv(128, 256, 128, true)
      local block2 = nn.Sequential()
      block2:add(unet_conv(64, 64, 128, true))
      block2:add(unet_branch(block1, backend, 128, 128, 4))
      block2:add(unet_conv(128, 64, 64, true))
      local model = nn.Sequential()
      model:add(unet_conv(ch, 32, 64, false))
      model:add(unet_branch(block2, backend, 64, 64, 16))
      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1))
      if deconv then
         model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
      else
         model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
      end
      return model
   end
   local model = nn.Sequential()
   local con = nn.ConcatTable()
   local aux_con = nn.ConcatTable()
   -- 2 cascade
   model:add(unet(backend, ch, true))
   con:add(unet(backend, ch, false))
   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
   model:add(con)
   model:add(aux_con)
   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
   model.w2nn_arch_name = "upcunet"
   model.w2nn_offset = 60
   model.w2nn_scale_factor = 2
   model.w2nn_channels = ch
   model.w2nn_resize = true
   model.w2nn_valid_input_size = {}
   for i = 76, 512, 4 do
      table.insert(model.w2nn_valid_input_size, i)
   end
   return model
end
-- cascaded residual spatial channel attention unet
function srcnn.upcunet_v2(backend, ch)
   local function unet_branch(insert, backend, n_input, n_output, depad)
      local block = nn.Sequential()
      local con = nn.ConcatTable(2)
      local model = nn.Sequential()
      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
      block:add(insert)
      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
      con:add(block)
      model:add(con)
      model:add(nn.CAddTable())
      return model
   end
   local function unet_conv(n_input, n_middle, n_output, se)
      local model = nn.Sequential()
      model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      if se then
         model:add(SpatialSEBlock(backend, 4, n_output, 4))
         model:add(nn.CMulTable())
      end
      return model
   end
   -- Residual U-Net
   local function unet(backend, in_ch, out_ch, deconv)
      local block1 = unet_conv(128, 256, 128, true)
      local block2 = nn.Sequential()
      block2:add(unet_conv(64, 64, 128, true))
      block2:add(unet_branch(block1, backend, 128, 128, 4))
      block2:add(unet_conv(128, 64, 64, true))
      local model = nn.Sequential()
      model:add(unet_conv(in_ch, 32, 64, false))
      model:add(unet_branch(block2, backend, 64, 64, 16))
      if deconv then
         model:add(SpatialFullConvolution(backend, 64, out_ch, 4, 4, 2, 2, 3, 3):noBias())
      else
         model:add(SpatialConvolution(backend, 64, out_ch, 3, 3, 1, 1, 0, 0):noBias())
      end
      return model
   end
   local model = nn.Sequential()
   local con = nn.ConcatTable()
   local aux_con = nn.ConcatTable()
   -- 2 cascade
   model:add(unet(backend, ch, ch, true))
   con:add(nn.Sequential():add(unet(backend, ch, ch, false)):add(nn.SpatialZeroPadding(-1, -1, -1, -1))) -- -1 for odd output size
   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
   model:add(con)
   model:add(aux_con)
   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
   model.w2nn_arch_name = "upcunet_v2"
   model.w2nn_offset = 58
   model.w2nn_scale_factor = 2
   model.w2nn_channels = ch
   model.w2nn_resize = true
   -- {76,92,108,140} are also valid sizes, but they are too small
   model.w2nn_valid_input_size = {156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
   return model
end
-- cascaded residual channel attention unet
function srcnn.upcunet_v3(backend, ch)
   local function unet_branch(insert, backend, n_input, n_output, depad)
      local block = nn.Sequential()
      local con = nn.ConcatTable(2)
      local model = nn.Sequential()
      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
      block:add(nn.LeakyReLU(0.1, true))
      block:add(insert)
      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
      block:add(nn.LeakyReLU(0.1, true))
      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
      con:add(block)
      model:add(con)
      model:add(nn.CAddTable())
      return model
   end
   local function unet_conv(n_input, n_middle, n_output, se)
      local model = nn.Sequential()
      model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1, true))
      if se then
         model:add(SEBlock(backend, n_output, 4))
         model:add(w2nn.ScaleTable())
      end
      return model
   end
   -- Residual U-Net
   local function unet(backend, ch, deconv)
      local block1 = unet_conv(128, 256, 128, true)
      local block2 = nn.Sequential()
      block2:add(unet_conv(64, 64, 128, true))
      block2:add(unet_branch(block1, backend, 128, 128, 4))
      block2:add(unet_conv(128, 64, 64, true))
      local model = nn.Sequential()
      model:add(unet_conv(ch, 32, 64, false))
      model:add(unet_branch(block2, backend, 64, 64, 16))
      model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
      model:add(nn.LeakyReLU(0.1))
      if deconv then
         model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
      else
         model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
      end
      return model
   end
   local model = nn.Sequential()
   local con = nn.ConcatTable()
   local aux_con = nn.ConcatTable()
   -- 2 cascade
   model:add(unet(backend, ch, true))
   con:add(unet(backend, ch, false))
   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
   model:add(con)
   model:add(aux_con)
   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
   model.w2nn_arch_name = "upcunet_v3"
   model.w2nn_offset = 60
   model.w2nn_scale_factor = 2
   model.w2nn_channels = ch
   model.w2nn_resize = true
   model.w2nn_valid_input_size = {}
   for i = 76, 512, 4 do
      table.insert(model.w2nn_valid_input_size, i)
   end
   return model
end
local function bench()
   local sys = require 'sys'
   cudnn.benchmark = true
   local model = nil
   local arch = {"upconv_7", "upcunet", "upcunet_v3"}
   local backend = "cudnn"
   for k = 1, #arch do
      model = srcnn[arch[k]](backend, 3):cuda()
      model:evaluate()
      local dummy = nil
      -- warm up
      for i = 1, 20 do
         local x = torch.Tensor(4, 3, 172, 172):uniform():cuda()
         model:forward(x)
      end
      local t = sys.clock()
      for i = 1, 20 do
         local x = torch.Tensor(4, 3, 172, 172):uniform():cuda()
         local z = model:forward(x)
         if dummy == nil then
            dummy = z:clone()
         else
            dummy:add(z)
         end
      end
      print(arch[k], sys.clock() - t)
      model:clearState()
   end
end
function srcnn.create(model_name, backend, color)
   model_name = model_name or "vgg_7"
   backend = backend or "cunn"
   color = color or "rgb"
   local ch = 3
   if color == "rgb" then
      ch = 3
   elseif color == "y" then
      ch = 1
   else
      error("unsupported color: " .. color)
   end
   if srcnn[model_name] then
      local model = srcnn[model_name](backend, ch)
      assert(model.w2nn_offset % model.w2nn_scale_factor == 0)
      return model
   else
      error("unsupported model_name: " .. model_name)
   end
end
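-- Minimal usage sketch for srcnn.create (model name and backend/color values are
-- ones accepted above; running on CPU-constructed modules is assumed here):
--   local model = srcnn.create("upconv_7", "cunn", "rgb")
--   print(srcnn.name(model), srcnn.scale_factor(model), srcnn.offset_size(model))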
--[[
local model = srcnn.cunet_v3("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()):size())
local model = srcnn.upcunet_v2("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
os.exit()
local model = srcnn.upcunet_v3("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
os.exit()
bench()
--]]
return srcnn