-- pairwise_transform_utils.lua (9.0 KB)

  1. require 'cunn'
  2. local iproc = require 'iproc'
  3. local gm = {}
  4. gm.Image = require 'graphicsmagick.Image'
  5. local data_augmentation = require 'data_augmentation'
  6. local pairwise_transform_utils = {}
  7. function pairwise_transform_utils.random_half(src, p, filters)
  8. if torch.uniform() < p then
  9. local filter = filters[torch.random(1, #filters)]
  10. return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
  11. else
  12. return src
  13. end
  14. end
  15. function pairwise_transform_utils.crop_if_large(src, max_size, mod)
  16. local tries = 4
  17. if src:size(2) > max_size and src:size(3) > max_size then
  18. assert(max_size % 4 == 0)
  19. local rect
  20. for i = 1, tries do
  21. local yi = torch.random(0, src:size(2) - max_size)
  22. local xi = torch.random(0, src:size(3) - max_size)
  23. if mod then
  24. yi = yi - (yi % mod)
  25. xi = xi - (xi % mod)
  26. end
  27. rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
  28. -- ignore simple background
  29. if rect:float():std() >= 0 then
  30. break
  31. end
  32. end
  33. return rect
  34. else
  35. return src
  36. end
  37. end
  38. function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
  39. local tries = 4
  40. if y:size(2) > max_size and y:size(3) > max_size then
  41. assert(max_size % 4 == 0)
  42. local rect_x, rect_y
  43. for i = 1, tries do
  44. local yi = torch.random(0, y:size(2) - max_size)
  45. local xi = torch.random(0, y:size(3) - max_size)
  46. if mod then
  47. yi = yi - (yi % mod)
  48. xi = xi - (xi % mod)
  49. end
  50. rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
  51. rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y)
  52. -- ignore simple background
  53. if rect_y:float():std() >= 0 then
  54. break
  55. end
  56. end
  57. return rect_x, rect_y
  58. else
  59. return x, y
  60. end
  61. end
  62. function pairwise_transform_utils.preprocess(src, crop_size, options)
  63. local dest = src
  64. local box_only = false
  65. if options.data.filters then
  66. if #options.data.filters == 1 and options.data.filters[1] == "Box" then
  67. box_only = true
  68. end
  69. end
  70. if box_only then
  71. local mod = 2 -- assert pos % 2 == 0
  72. dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size), mod)
  73. dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
  74. dest = data_augmentation.overlay(dest, options.random_overlay_rate)
  75. dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
  76. dest = iproc.crop_mod4(dest)
  77. else
  78. dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
  79. dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
  80. dest = data_augmentation.blur(dest, options.random_blur_rate,
  81. options.random_blur_size,
  82. options.random_blur_sigma_min,
  83. options.random_blur_sigma_max)
  84. dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
  85. dest = data_augmentation.overlay(dest, options.random_overlay_rate)
  86. dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
  87. dest = data_augmentation.shift_1px(dest)
  88. end
  89. return dest
  90. end
  91. function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
  92. x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
  93. x, y = data_augmentation.pairwise_rotate(x, y,
  94. options.random_pairwise_rotate_rate,
  95. options.random_pairwise_rotate_min,
  96. options.random_pairwise_rotate_max)
  97. local scale_min = math.max(options.random_pairwise_scale_min, size / (1 + math.min(x:size(2), x:size(3))))
  98. local scale_max = math.max(scale_min, options.random_pairwise_scale_max)
  99. x, y = data_augmentation.pairwise_scale(x, y,
  100. options.random_pairwise_scale_rate,
  101. scale_min,
  102. scale_max)
  103. x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
  104. x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
  105. x = iproc.crop_mod4(x)
  106. y = iproc.crop_mod4(y)
  107. return x, y
  108. end
  109. function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
  110. assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
  111. assert("crop_size % scale == 0", size % scale == 0)
  112. local r = torch.uniform()
  113. local t = "float"
  114. if x:type() == "torch.ByteTensor" then
  115. t = "byte"
  116. end
  117. if p < r then
  118. local xi = 0
  119. local yi = 0
  120. if x:size(3) > size + 1 then
  121. xi = torch.random(0, x:size(3) - (size + 1)) * scale
  122. end
  123. if x:size(2) > size + 1 then
  124. yi = torch.random(0, x:size(2) - (size + 1)) * scale
  125. end
  126. local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
  127. local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
  128. return xc, yc
  129. else
  130. local xcs = torch.LongTensor(tries, y:size(1), size, size)
  131. local lcs = torch.LongTensor(tries, lowres_y:size(1), size, size)
  132. local rects = {}
  133. local r = torch.LongTensor(2, tries)
  134. r[1]:random(1, x:size(3) - (size + 1)):mul(scale)
  135. r[2]:random(1, x:size(2) - (size + 1)):mul(scale)
  136. for i = 1, tries do
  137. local xi = r[1][i]
  138. local yi = r[2][i]
  139. local xc = iproc.crop_nocopy(y, xi, yi, xi + size, yi + size)
  140. local lc = iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size)
  141. xcs[i]:copy(xc)
  142. lcs[i]:copy(lc)
  143. rects[i] = {xi, yi}
  144. end
  145. xcs:csub(lcs)
  146. xcs:cmul(xcs)
  147. local v, l = xcs:reshape(xcs:size(1), xcs:nElement() / xcs:size(1)):transpose(1, 2):sum(1):topk(1, true)
  148. local best_xi = rects[l[1][1]][1]
  149. local best_yi = rects[l[1][1]][2]
  150. local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
  151. local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + size / scale, best_yi / scale + size / scale)
  152. return xc, yc
  153. end
  154. end
  155. function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
  156. local xs = {}
  157. local ns = {}
  158. local ys = {}
  159. local ls = {}
  160. for j = 1, 2 do
  161. -- TTA
  162. local xi, yi, ri, ni
  163. if j == 1 then
  164. xi = x
  165. ni = x_noise
  166. yi = y
  167. ri = lowres_y
  168. else
  169. xi = x:transpose(2, 3):contiguous()
  170. if x_noise then
  171. ni = x_noise:transpose(2, 3):contiguous()
  172. end
  173. yi = y:transpose(2, 3):contiguous()
  174. if lowres_y then
  175. ri = lowres_y:transpose(2, 3):contiguous()
  176. end
  177. end
  178. local xv = iproc.vflip(xi)
  179. local nv
  180. if x_noise then
  181. nv = iproc.vflip(ni)
  182. end
  183. local yv = iproc.vflip(yi)
  184. local rv
  185. if ri then
  186. rv = iproc.vflip(ri)
  187. end
  188. table.insert(xs, xi)
  189. if ni then
  190. table.insert(ns, ni)
  191. end
  192. table.insert(ys, yi)
  193. if ri then
  194. table.insert(ls, ri)
  195. end
  196. table.insert(xs, xv)
  197. if nv then
  198. table.insert(ns, nv)
  199. end
  200. table.insert(ys, yv)
  201. if rv then
  202. table.insert(ls, rv)
  203. end
  204. table.insert(xs, iproc.hflip(xi))
  205. if ni then
  206. table.insert(ns, iproc.hflip(ni))
  207. end
  208. table.insert(ys, iproc.hflip(yi))
  209. if ri then
  210. table.insert(ls, iproc.hflip(ri))
  211. end
  212. table.insert(xs, iproc.hflip(xv))
  213. if nv then
  214. table.insert(ns, iproc.hflip(nv))
  215. end
  216. table.insert(ys, iproc.hflip(yv))
  217. if rv then
  218. table.insert(ls, iproc.hflip(rv))
  219. end
  220. end
  221. return xs, ys, ls, ns
  222. end
  223. local function lowres_model()
  224. local seq = nn.Sequential()
  225. seq:add(nn.SpatialAveragePooling(2, 2, 2, 2))
  226. seq:add(nn.SpatialUpSamplingNearest(2))
  227. return seq:cuda()
  228. end
  229. local g_lowres_model = nil
  230. local g_lowres_gpu = nil
  231. function pairwise_transform_utils.low_resolution(src)
  232. --[[
  233. -- I am not sure that the following process is thraed-safe
  234. g_lowres_model = g_lowres_model or lowres_model()
  235. if g_lowres_gpu == nil then
  236. --benchmark
  237. local gpu_time = sys.clock()
  238. for i = 1, 10 do
  239. g_lowres_model:forward(src:cuda()):byte()
  240. end
  241. gpu_time = sys.clock() - gpu_time
  242. local cpu_time = sys.clock()
  243. for i = 1, 10 do
  244. gm.Image(src, "RGB", "DHW"):
  245. size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
  246. size(src:size(3), src:size(2), "Box"):
  247. toTensor("byte", "RGB", "DHW")
  248. end
  249. cpu_time = sys.clock() - cpu_time
  250. --print(gpu_time, cpu_time)
  251. if gpu_time < cpu_time then
  252. g_lowres_gpu = true
  253. else
  254. g_lowres_gpu = false
  255. end
  256. end
  257. if g_lowres_gpu then
  258. return g_lowres_model:forward(src:cuda()):byte()
  259. else
  260. return gm.Image(src, "RGB", "DHW"):
  261. size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
  262. size(src:size(3), src:size(2), "Box"):
  263. toTensor("byte", "RGB", "DHW")
  264. end
  265. --]]
  266. if src:size(1) == 1 then
  267. return gm.Image(src, "I", "DHW"):
  268. size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
  269. size(src:size(3), src:size(2), "Box"):
  270. toTensor("byte", "I", "DHW")
  271. else
  272. return gm.Image(src, "RGB", "DHW"):
  273. size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
  274. size(src:size(3), src:size(2), "Box"):
  275. toTensor("byte", "RGB", "DHW")
  276. end
  277. end
  278. return pairwise_transform_utils