reconstruct.lua 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. require 'image'
  2. local iproc = require 'iproc'
  3. local srcnn = require 'srcnn'
  -- Runs a single-plane model over x tile by tile.
  -- x may be 2-D (H, W); it is then reshaped to (1, H, W). The result has the
  -- same shape as that (reshaped) x and starts out all zeros.
  -- Square windows of block_size pixels are read with a stride of
  -- block_size - 2*offset. The model output for each window, viewed as
  -- (1, stride, stride), is written offset pixels inside that window.
  -- Windows that would extend past the bottom or right edge are skipped,
  -- so pixels never covered by an output stay 0. Input windows are staged in
  -- a torch.CudaTensor, so CUDA tensor support is required.
  local function reconstruct_y(model, x, offset, block_size)
    if x:dim() == 2 then
      x = x:reshape(1, x:size(1), x:size(2))
    end
    local height, width = x:size(2), x:size(3)
    local out_len = block_size - offset * 2
    local result = torch.Tensor():resizeAs(x):zero()
    local tile = torch.CudaTensor(1, 1, block_size, block_size)
    for top = 1, height, out_len do
      local bottom = top + block_size - 1
      for left = 1, width, out_len do
        local right = left + block_size - 1
        if bottom <= height and right <= width then
          tile:copy(x[{{}, {top, bottom}, {left, right}}])
          local pred = model:forward(tile):view(1, out_len, out_len)
          result[{{},
                  {top + offset, top + offset + out_len - 1},
                  {left + offset, left + offset + out_len - 1}}]:copy(pred)
        end
      end
    end
    return result
  end
  -- Runs a 3-channel model over x tile by tile.
  -- x is expected to be (3, H, W), since each window is copied into a
  -- (1, 3, block_size, block_size) CUDA tensor. The result has x's shape and
  -- starts out all zeros.
  -- Square windows of block_size pixels are read with a stride of
  -- block_size - 2*offset. The model output for each window, viewed as
  -- (3, stride, stride), is written offset pixels inside that window.
  -- Windows that would extend past the bottom or right edge are skipped,
  -- so pixels never covered by an output stay 0.
  local function reconstruct_rgb(model, x, offset, block_size)
    local height, width = x:size(2), x:size(3)
    local out_len = block_size - offset * 2
    local result = torch.Tensor():resizeAs(x):zero()
    local tile = torch.CudaTensor(1, 3, block_size, block_size)
    for top = 1, height, out_len do
      local bottom = top + block_size - 1
      for left = 1, width, out_len do
        local right = left + block_size - 1
        if bottom <= height and right <= width then
          tile:copy(x[{{}, {top, bottom}, {left, right}}])
          local pred = model:forward(tile):view(3, out_len, out_len)
          result[{{},
                  {top + offset, top + offset + out_len - 1},
                  {left + offset, left + offset + out_len - 1}}]:copy(pred)
        end
      end
    end
    return result
  end
  -- Module table returned at the end of this file.
  local reconstruct = {}

  --- Tells whether a model operates on all 3 channels rather than on Y only.
  -- @param model  network whose channel count is read with srcnn.channels
  -- @return true when srcnn.channels(model) equals 3, false otherwise
  function reconstruct.is_rgb(model)
    local channels = srcnn.channels(model)
    return channels == 3
  end
  --- Border width (in pixels) that the model trims from each side of a tile.
  -- Thin forwarder to srcnn.offset_size; the tail call passes every return
  -- value through unchanged.
  -- @param model  network to query
  reconstruct.offset_size = function (model)
    return srcnn.offset_size(model)
  end
  --- Reconstructs an image at its original size with a Y-only (1-channel) model.
  -- Processing steps:
  --   1. Pad the input and convert it to YUV with image.rgb2yuv.
  --   2. Pass only the Y plane through the model (via reconstruct_y) and clamp it to [0, 1].
  --   3. Write the Y plane back, crop the padding away, convert with image.yuv2rgb,
  --      and clamp again.
  -- @param model       network forwarded tile by tile by reconstruct_y
  -- @param x           image tensor indexed (channel, height, width); it goes
  --                    through image.rgb2yuv, so presumably 3 channels
  -- @param offset      border width in pixels; each model output tile is
  --                    block_size - 2*offset pixels wide
  -- @param block_size  edge of the square tiles fed to the model (default 128);
  --                    must be greater than 2 * offset
  -- @return RGB tensor clamped to [0, 1]
  function reconstruct.image_y(model, x, offset, block_size)
    block_size = block_size or 128
    local step = block_size - 2 * offset
    -- A non-positive step would make the padding below inf/negative and fail
    -- somewhere deep inside iproc; report the real cause instead.
    assert(step > 0, "block_size must be greater than 2 * offset")
    -- Pad so the image plus an `offset` border on every side tiles exactly
    -- into whole output steps: `offset` on top/left, the remainder on
    -- bottom/right.
    local pad_h1, pad_w1 = offset, offset
    local pad_h2 = offset + math.ceil(x:size(2) / step) * step - x:size(2)
    local pad_w2 = offset + math.ceil(x:size(3) / step) * step - x:size(3)
    x = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
    local y = reconstruct_y(model, x[1], offset, block_size)
    y:clamp(0, 1)
    x[1]:copy(y)
    local output = image.yuv2rgb(iproc.crop(x,
                                            pad_w1, pad_h1,
                                            x:size(3) - pad_w2, x:size(2) - pad_h2))
    output:clamp(0, 1)
    -- Drop references to the padded intermediates before collecting.
    x = nil
    y = nil
    collectgarbage()
    return output
  end
  --- Upscales an image by `scale` and refines its luminance with a Y-only model.
  -- Two upscaled copies are made:
  --   * with `upsampling_filter`: its Y plane is fed to the model;
  --   * with "Lanczos": supplies the U/V planes of the result.
  -- Both copies are padded and converted to YUV. The model output replaces the
  -- Lanczos copy's Y plane after clamping to [0, 1]. The padding is then cropped
  -- away and the result converted back with image.yuv2rgb.
  -- @param model              network forwarded tile by tile by reconstruct_y
  -- @param scale              factor applied to width and height
  -- @param x                  image tensor indexed (channel, height, width);
  --                           presumably 3 channels (goes through image.rgb2yuv)
  -- @param offset             border width in pixels; each model output tile is
  --                           block_size - 2*offset pixels wide
  -- @param block_size         tile edge fed to the model (default 128); must be
  --                           greater than 2 * offset
  -- @param upsampling_filter  filter name passed to iproc.scale (default "Box")
  -- @return RGB tensor clamped to [0, 1]
  function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    local step = block_size - 2 * offset
    -- Fail early with a clear message instead of inf/negative padding later.
    assert(step > 0, "block_size must be greater than 2 * offset")
    local new_w, new_h = x:size(3) * scale, x:size(2) * scale
    local x_lanczos = iproc.scale(x, new_w, new_h, "Lanczos")
    x = iproc.scale(x, new_w, new_h, upsampling_filter)
    if x:size(2) * x:size(3) > 2048 * 2048 then
      collectgarbage()
    end
    -- Pad so the image plus an `offset` border on every side tiles exactly
    -- into whole output steps: `offset` on top/left, the remainder on
    -- bottom/right.
    local pad_h1, pad_w1 = offset, offset
    local pad_h2 = offset + math.ceil(x:size(2) / step) * step - x:size(2)
    local pad_w2 = offset + math.ceil(x:size(3) / step) * step - x:size(3)
    x = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
    x_lanczos = image.rgb2yuv(iproc.padding(x_lanczos, pad_w1, pad_w2, pad_h1, pad_h2))
    local y = reconstruct_y(model, x[1], offset, block_size)
    y:clamp(0, 1)
    x_lanczos[1]:copy(y)
    local output = image.yuv2rgb(iproc.crop(x_lanczos,
                                            pad_w1, pad_h1,
                                            x_lanczos:size(3) - pad_w2,
                                            x_lanczos:size(2) - pad_h2))
    output:clamp(0, 1)
    -- Drop references to the padded intermediates before collecting.
    x = nil
    x_lanczos = nil
    y = nil
    collectgarbage()
    return output
  end
  --- Reconstructs an image at its original size with a 3-channel model.
  -- The input is padded, run through the model tile by tile (via
  -- reconstruct_rgb), cropped back by the same padding and clamped to [0, 1].
  -- @param model       network forwarded tile by tile by reconstruct_rgb
  -- @param x           image tensor indexed (channel, height, width); 3 channels
  --                    expected by reconstruct_rgb
  -- @param offset      border width in pixels; each model output tile is
  --                    block_size - 2*offset pixels wide
  -- @param block_size  tile edge fed to the model (default 128); must be
  --                    greater than 2 * offset
  -- @return tensor clamped to [0, 1]
  function reconstruct.image_rgb(model, x, offset, block_size)
    block_size = block_size or 128
    local step = block_size - 2 * offset
    -- Fail early with a clear message instead of inf/negative padding later.
    assert(step > 0, "block_size must be greater than 2 * offset")
    -- Pad so the image plus an `offset` border on every side tiles exactly
    -- into whole output steps: `offset` on top/left, the remainder on
    -- bottom/right.
    local pad_h1, pad_w1 = offset, offset
    local pad_h2 = offset + math.ceil(x:size(2) / step) * step - x:size(2)
    local pad_w2 = offset + math.ceil(x:size(3) / step) * step - x:size(3)
    x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
    if x:size(2) * x:size(3) > 2048 * 2048 then
      collectgarbage()
    end
    local y = reconstruct_rgb(model, x, offset, block_size)
    local output = iproc.crop(y,
                              pad_w1, pad_h1,
                              y:size(3) - pad_w2, y:size(2) - pad_h2)
    output:clamp(0, 1)
    -- Drop references to the padded intermediates before collecting.
    x = nil
    y = nil
    collectgarbage()
    return output
  end
  --- Upscales an image by `scale` and refines it with a 3-channel model.
  -- Processing steps:
  --   1. Resize the input with iproc.scale using `upsampling_filter`.
  --   2. Pad it and run it through the model tile by tile (via reconstruct_rgb).
  --   3. Crop the padding away and clamp the result to [0, 1].
  -- @param model              network forwarded tile by tile by reconstruct_rgb
  -- @param scale              factor applied to width and height
  -- @param x                  image tensor indexed (channel, height, width);
  --                           3 channels expected by reconstruct_rgb
  -- @param offset             border width in pixels; each model output tile is
  --                           block_size - 2*offset pixels wide
  -- @param block_size         tile edge fed to the model (default 128); must be
  --                           greater than 2 * offset
  -- @param upsampling_filter  filter name passed to iproc.scale (default "Box")
  -- @return tensor clamped to [0, 1]
  function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    local step = block_size - 2 * offset
    -- Fail early with a clear message instead of inf/negative padding later.
    assert(step > 0, "block_size must be greater than 2 * offset")
    x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
    if x:size(2) * x:size(3) > 2048 * 2048 then
      collectgarbage()
    end
    -- Pad so the image plus an `offset` border on every side tiles exactly
    -- into whole output steps: `offset` on top/left, the remainder on
    -- bottom/right.
    local pad_h1, pad_w1 = offset, offset
    local pad_h2 = offset + math.ceil(x:size(2) / step) * step - x:size(2)
    local pad_w2 = offset + math.ceil(x:size(3) / step) * step - x:size(3)
    x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
    if x:size(2) * x:size(3) > 2048 * 2048 then
      collectgarbage()
    end
    local y = reconstruct_rgb(model, x, offset, block_size)
    local output = iproc.crop(y,
                              pad_w1, pad_h1,
                              y:size(3) - pad_w2, y:size(2) - pad_h2)
    output:clamp(0, 1)
    -- Drop references to the padded intermediates before collecting.
    x = nil
    y = nil
    collectgarbage()
    return output
  end
  --- Runs the model over x without changing its size.
  -- Dispatches to image_rgb or image_y depending on reconstruct.is_rgb(model).
  -- A single-channel input is first replicated into 3 identical channels. The
  -- result is turned back into one channel with image.rgb2y.
  -- @param model       network to apply
  -- @param x           image tensor indexed (channel, height, width)
  -- @param block_size  optional tile edge forwarded to the worker
  -- @return reconstructed tensor (1 channel if x had 1 channel)
  function reconstruct.image(model, x, block_size)
    local was_gray = x:size(1) == 1
    if was_gray then
      local rgb = torch.Tensor(3, x:size(2), x:size(3))
      for c = 1, 3 do
        rgb[c]:copy(x)
      end
      x = rgb
    end
    local worker
    if reconstruct.is_rgb(model) then
      worker = reconstruct.image_rgb
    else
      worker = reconstruct.image_y
    end
    local result = worker(model, x, reconstruct.offset_size(model), block_size)
    if was_gray then
      result = image.rgb2y(result)
    end
    return result
  end
  --- Upscales x by `scale` with the model.
  -- Dispatches to scale_rgb or scale_y depending on reconstruct.is_rgb(model).
  -- A single-channel input is first replicated into 3 identical channels. The
  -- result is turned back into one channel with image.rgb2y.
  -- @param model              network to apply
  -- @param scale              factor applied to width and height
  -- @param x                  image tensor indexed (channel, height, width)
  -- @param block_size         optional tile edge forwarded to the worker
  -- @param upsampling_filter  optional filter name forwarded to the worker
  -- @return upscaled tensor (1 channel if x had 1 channel)
  function reconstruct.scale(model, scale, x, block_size, upsampling_filter)
    local was_gray = x:size(1) == 1
    if was_gray then
      local rgb = torch.Tensor(3, x:size(2), x:size(3))
      for c = 1, 3 do
        rgb[c]:copy(x)
      end
      x = rgb
    end
    local worker
    if reconstruct.is_rgb(model) then
      worker = reconstruct.scale_rgb
    else
      worker = reconstruct.scale_y
    end
    local result = worker(model, scale, x,
                          reconstruct.offset_size(model),
                          block_size,
                          upsampling_filter)
    if was_gray then
      result = image.rgb2y(result)
    end
    return result
  end
  -- Test-time augmentation wrapper.
  -- Calls f(model, input, offset, block_size) on 8 transformed copies of x.
  -- The copies are every combination of one flip (none, vflip, hflip, both)
  -- and one transpose (none, or swapping dims 2 and 3). Each transform is
  -- undone on f's result, and the 8 restored outputs are averaged.
  -- offset comes from reconstruct.offset_size(model).
  local function tta(f, model, x, block_size)
    local offset = reconstruct.offset_size(model)
    local identity = function (a) return a end
    local swap_hw = function (a) return a:transpose(2, 3):contiguous() end
    -- Each entry is {forward transform, inverse transform}.
    local flips = {
      {identity, identity},
      {image.vflip, image.vflip},
      {image.hflip, image.hflip},
      {function (a) return image.hflip(image.vflip(a)) end,
       function (a) return image.vflip(image.hflip(a)) end},
    }
    local transposes = {
      {identity, identity},
      {swap_hw, swap_hw},
    }
    local sum = nil
    for _, flip in ipairs(flips) do
      for _, tr in ipairs(transposes) do
        local augmented = flip[1](tr[1](x))
        local restored = tr[2](flip[2](f(model, augmented, offset, block_size)))
        if sum == nil then
          sum = restored
        else
          sum:add(restored)
        end
      end
    end
    return sum:div(8.0)
  end
  --- Same-size reconstruction averaged over 8 flip/transpose augmentations.
  -- Uses image_rgb or image_y depending on reconstruct.is_rgb(model).
  -- Like reconstruct.image, a single-channel input is replicated into 3
  -- identical channels first. Both workers need 3 channels: image_rgb copies
  -- tiles into a 3-channel buffer and image_y calls image.rgb2yuv. The
  -- averaged result is converted back with image.rgb2y.
  -- @param model       network to apply
  -- @param x           image tensor indexed (channel, height, width)
  -- @param block_size  optional tile edge forwarded to the worker
  -- @return reconstructed tensor (1 channel if x had 1 channel)
  function reconstruct.image_tta(model, x, block_size)
    local i2rgb = false
    if x:size(1) == 1 then
      local new_x = torch.Tensor(3, x:size(2), x:size(3))
      new_x[1]:copy(x)
      new_x[2]:copy(x)
      new_x[3]:copy(x)
      x = new_x
      i2rgb = true
    end
    local f
    if reconstruct.is_rgb(model) then
      f = reconstruct.image_rgb
    else
      f = reconstruct.image_y
    end
    local output = tta(f, model, x, block_size)
    if i2rgb then
      output = image.rgb2y(output)
    end
    return output
  end
  --- Upscaling averaged over 8 flip/transpose augmentations.
  -- Uses scale_rgb or scale_y depending on reconstruct.is_rgb(model); `scale`
  -- and `upsampling_filter` are bound into the worker passed to tta.
  -- Like reconstruct.scale, a single-channel input is replicated into 3
  -- identical channels first, because both workers need 3 channels. The
  -- averaged result is converted back with image.rgb2y.
  -- @param model              network to apply
  -- @param scale              factor applied to width and height
  -- @param x                  image tensor indexed (channel, height, width)
  -- @param block_size         optional tile edge forwarded to the worker
  -- @param upsampling_filter  optional filter name forwarded to the worker
  -- @return upscaled tensor (1 channel if x had 1 channel)
  function reconstruct.scale_tta(model, scale, x, block_size, upsampling_filter)
    local i2rgb = false
    if x:size(1) == 1 then
      local new_x = torch.Tensor(3, x:size(2), x:size(3))
      new_x[1]:copy(x)
      new_x[2]:copy(x)
      new_x[3]:copy(x)
      x = new_x
      i2rgb = true
    end
    local f
    if reconstruct.is_rgb(model) then
      f = function (m, input, off, bs)
        return reconstruct.scale_rgb(m, scale, input, off, bs, upsampling_filter)
      end
    else
      f = function (m, input, off, bs)
        return reconstruct.scale_y(m, scale, input, off, bs, upsampling_filter)
      end
    end
    local output = tta(f, model, x, block_size)
    if i2rgb then
      output = image.rgb2y(output)
    end
    return output
  end
  return reconstruct