pairwise_transform_utils.lua 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. require 'cunn'
  2. local iproc = require 'iproc'
  3. local gm = {}
  4. gm.Image = require 'graphicsmagick.Image'
  5. local data_augmentation = require 'data_augmentation'
  6. local pairwise_transform_utils = {}
  --- Downscale `src` to half size with probability `p`.
  -- One uniform sample decides whether to scale; when scaling, the resize
  -- filter is drawn at random from `filters`.
  -- @param src      image tensor; size(2)/size(3) are halved and passed to
  --                 iproc.scale (presumably height/width -- TODO confirm)
  -- @param p        probability threshold compared against torch.uniform()
  -- @param filters  array of filter identifiers accepted by iproc.scale
  -- @return the scaled image, or `src` itself when no scaling is applied
  function pairwise_transform_utils.random_half(src, p, filters)
    if not (torch.uniform() < p) then
      return src
    end
    local chosen = filters[torch.random(1, #filters)]
    local half_w = src:size(3) * 0.5
    local half_h = src:size(2) * 0.5
    return iproc.scale(src, half_w, half_h, chosen)
  end
  --- Randomly crop a `max_size` x `max_size` window out of `src` when `src`
  -- exceeds `max_size` along both size(2) and size(3).
  -- Up to four random positions are tried. The first crop whose pixel
  -- standard deviation is non-zero is returned; if every attempt is flat, the
  -- last attempt is returned.
  -- @param src       image tensor (size(2) is used for the y range, size(3)
  --                  for the x range)
  -- @param max_size  side length of the crop; must be a multiple of 4
  -- @param mod       optional; when given, both offsets are rounded down to a
  --                  multiple of `mod`
  -- @return the cropped tensor, or `src` unchanged when it is not larger than
  --         `max_size` in both dimensions
  function pairwise_transform_utils.crop_if_large(src, max_size, mod)
    if not (src:size(2) > max_size and src:size(3) > max_size) then
      return src
    end
    assert(max_size % 4 == 0, "max_size must be a multiple of 4")
    local tries = 4
    local rect
    for _ = 1, tries do
      local yi = torch.random(0, src:size(2) - max_size)
      local xi = torch.random(0, src:size(3) - max_size)
      if mod then
        yi = yi - (yi % mod)
        xi = xi - (xi % mod)
      end
      rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
      -- Ignore simple background: a standard deviation of exactly zero means
      -- the crop is one flat value, so try another position. (The previous
      -- `>= 0` test was always true and disabled the retries.)
      if rect:float():std() > 0 then
        break
      end
    end
    return rect
  end
  --- Crop a matching random window out of a low-res/high-res pair when `y`
  -- exceeds `max_size` along both size(2) and size(3).
  -- The window position is chosen in `y` coordinates; the window applied to
  -- `x` is the same rectangle with every coordinate divided by `scale_y`.
  -- Up to four positions are tried and the first whose `y` crop has a
  -- non-zero standard deviation is used; if all are flat, the last one is used.
  -- @param x         tensor paired with `y` (presumably `y` downscaled by
  --                  `scale_y` -- TODO confirm against callers)
  -- @param y         tensor whose size decides whether cropping happens
  -- @param scale_y   divisor mapping `y` coordinates to `x` coordinates
  -- @param max_size  side length of the `y` crop; must be a multiple of 4
  -- @param mod       optional; when given, both offsets are rounded down to a
  --                  multiple of `mod`
  -- @return rect_x, rect_y  the cropped pair, or `x, y` unchanged when `y` is
  --                         not larger than `max_size` in both dimensions
  function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
    if not (y:size(2) > max_size and y:size(3) > max_size) then
      return x, y
    end
    assert(max_size % 4 == 0, "max_size must be a multiple of 4")
    local tries = 4
    local rect_y
    local xi, yi
    for _ = 1, tries do
      yi = torch.random(0, y:size(2) - max_size)
      xi = torch.random(0, y:size(3) - max_size)
      if mod then
        yi = yi - (yi % mod)
        xi = xi - (xi % mod)
      end
      rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
      -- Ignore simple background: a standard deviation of exactly zero means
      -- the crop is one flat value, so try another position. (The previous
      -- `>= 0` test was always true and disabled the retries.)
      if rect_y:float():std() > 0 then
        break
      end
    end
    -- Only the x crop for the finally selected position is needed, so it is
    -- computed once here instead of on every attempt.
    local rect_x = iproc.crop(x,
                              xi / scale_y,
                              yi / scale_y,
                              xi / scale_y + max_size / scale_y,
                              yi / scale_y + max_size / scale_y)
    return rect_x, rect_y
  end
  --- Apply the single-image augmentation pipeline to `src`.
  -- Two pipelines exist, selected by `options.data.filters`:
  --   * filters is exactly {"Box"}: crop (offsets aligned to 2), colour noise,
  --     overlay, unsharp mask, then iproc.crop_mod4.
  --   * otherwise: random half-scaling, crop, blur, colour noise, overlay,
  --     unsharp mask, then data_augmentation.shift_1px.
  -- @param src        input image tensor
  -- @param crop_size  training crop size; the large-image crop limit is
  --                   max(crop_size * 2, options.max_size)
  -- @param options    settings table (reads data.filters, max_size,
  --                   downsampling_filters and the random_* rates/ranges)
  -- @return the augmented image
  function pairwise_transform_utils.preprocess(src, crop_size, options)
    local filters = options.data.filters
    local box_only = false
    if filters and #filters == 1 and filters[1] == "Box" then
      box_only = true
    end

    local crop_limit = math.max(crop_size * 2, options.max_size)

    -- Colour-space augmentations shared by both pipelines, in fixed order.
    local function color_augment(img)
      img = data_augmentation.color_noise(img, options.random_color_noise_rate)
      img = data_augmentation.overlay(img, options.random_overlay_rate)
      img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
      return img
    end

    local out
    if box_only then
      local mod = 2 -- assert pos % 2 == 0
      out = pairwise_transform_utils.crop_if_large(src, crop_limit, mod)
      out = color_augment(out)
      out = iproc.crop_mod4(out)
    else
      out = pairwise_transform_utils.random_half(src,
                                                 options.random_half_rate,
                                                 options.downsampling_filters)
      out = pairwise_transform_utils.crop_if_large(out, crop_limit)
      out = data_augmentation.blur(out,
                                   options.random_blur_rate,
                                   options.random_blur_size,
                                   options.random_blur_sigma_min,
                                   options.random_blur_sigma_max)
      out = color_augment(out)
      out = data_augmentation.shift_1px(out)
    end
    return out
  end
  --- Apply the pairwise augmentation pipeline to a user-supplied (x, y) pair.
  -- Steps, in order: paired crop of large images (offsets aligned to
  -- `scale_y`), paired rotate, paired scale, paired negate, negate-x,
  -- random erasing on `x` only, and finally iproc.crop_mod4 on both.
  -- @param x        first tensor of the pair
  -- @param y        second tensor of the pair
  -- @param scale_y  scale factor between `x` and `y`, forwarded to
  --                 crop_if_large_pair as both the scale and the offset modulus
  -- @param size     target size used to bound the minimum random scale
  -- @param options  settings table (reads max_size and the random_pairwise_* /
  --                 random_erasing_* fields)
  -- @return x, y    the augmented pair
  function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
    x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)

    x, y = data_augmentation.pairwise_rotate(
      x, y,
      options.random_pairwise_rotate_rate,
      options.random_pairwise_rotate_min,
      options.random_pairwise_rotate_max)

    -- Lower bound of the scale range: the configured minimum, or
    -- size / (1 + shorter side of x) when that is larger.
    local shorter_side = math.min(x:size(2), x:size(3))
    local lower = math.max(options.random_pairwise_scale_min, size / (1 + shorter_side))
    local upper = math.max(lower, options.random_pairwise_scale_max)
    x, y = data_augmentation.pairwise_scale(
      x, y,
      options.random_pairwise_scale_rate,
      lower,
      upper)

    x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
    x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)

    x = data_augmentation.erase(
      x,
      options.random_erasing_rate,
      options.random_erasing_n,
      options.random_erasing_rect_min,
      options.random_erasing_rect_max)

    x = iproc.crop_mod4(x)
    y = iproc.crop_mod4(y)
    return x, y
  end
  --- Crop a training patch pair from (x, y), optionally preferring the
  -- candidate where `y` differs most from `lowres_y`.
  -- One uniform sample r is drawn. If `p < r`, a single random position is
  -- used. Otherwise `tries` random positions are drawn, the sum of squared
  -- differences between the `y` crop and the `lowres_y` crop is computed for
  -- each, and the position with the largest sum is used.
  -- Offsets are chosen in `x` units and multiplied by `scale`; the `y` crop is
  -- `size` x `size` and the `x` crop is the same rectangle divided by `scale`.
  -- @param x         tensor with size(2)*scale == y:size(2) and
  --                  size(3)*scale == y:size(3)
  -- @param y         tensor the `size` x `size` crop is taken from
  -- @param lowres_y  tensor compared against `y` in the active branch; indexed
  --                  with the same coordinates as `y`
  -- @param size      crop side length in `y` coordinates; must be a multiple of
  --                  `scale`
  -- @param scale     scale factor between `x` and `y`
  -- @param p         threshold compared with a uniform sample; the active
  --                  branch runs when `p >= r`
  -- @param tries     number of candidate positions in the active branch
  -- @return xc, yc   the cropped `x` and `y` patches
  function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    -- assert(condition, message): the original call had the arguments
    -- swapped, so the (always truthy) message string was tested instead.
    assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
           "x:size == y:size")
    assert(size % scale == 0, "crop_size % scale == 0")

    local x_h, x_w = x:size(2), x:size(3)

    -- Crop the (x, y) pair for an offset expressed in y coordinates.
    local function crop_pair(xi, yi)
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x,
                            xi / scale,
                            yi / scale,
                            xi / scale + size / scale,
                            yi / scale + size / scale)
      return xc, yc
    end

    local r = torch.uniform()
    if p < r then
      -- Plain random crop; offset stays 0 along any axis with no slack.
      local xi = 0
      local yi = 0
      if x_w > size + 1 then
        xi = torch.random(0, x_w - (size + 1)) * scale
      end
      if x_h > size + 1 then
        yi = torch.random(0, x_h - (size + 1)) * scale
      end
      return crop_pair(xi, yi)
    end

    -- Active crop: score `tries` candidates by squared difference.
    -- NOTE(review): crops are copied into LongTensors, which assumes
    -- integer-valued pixel data (e.g. byte images) -- TODO confirm; fractional
    -- values would be truncated by the copy.
    local ycs = torch.LongTensor(tries, y:size(1), size, size)
    local lcs = torch.LongTensor(tries, lowres_y:size(1), size, size)
    local offsets = torch.LongTensor(2, tries)
    -- Same guard as the random branch: random(1, n) is only valid for n >= 1,
    -- so fall back to offset 0 along an axis with no slack.
    if x_w > size + 1 then
      offsets[1]:random(1, x_w - (size + 1)):mul(scale)
    else
      offsets[1]:zero()
    end
    if x_h > size + 1 then
      offsets[2]:random(1, x_h - (size + 1)):mul(scale)
    else
      offsets[2]:zero()
    end

    for i = 1, tries do
      local xi = offsets[1][i]
      local yi = offsets[2][i]
      ycs[i]:copy(iproc.crop_nocopy(y, xi, yi, xi + size, yi + size))
      lcs[i]:copy(iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size))
    end

    -- Per-candidate sum of squared differences; topk(1, true) picks the largest.
    ycs:csub(lcs)
    ycs:cmul(ycs)
    local _, best = ycs:reshape(ycs:size(1), ycs:nElement() / ycs:size(1))
                       :transpose(1, 2):sum(1):topk(1, true)
    local best_i = best[1][1]
    return crop_pair(offsets[1][best_i], offsets[2][best_i])
  end
  --- Expand a sample into its eight flip/transpose variants.
  -- Two passes are made: the inputs as given, then the inputs with
  -- dimensions 2 and 3 transposed. Each pass appends four entries per list,
  -- in this order: the image, its vflip, its hflip, and the hflip of its vflip.
  -- @param x         input tensor (required)
  -- @param y         target tensor (required)
  -- @param lowres_y  optional tensor; `ls` stays empty when it is absent
  -- @param x_noise   optional tensor; `ns` stays empty when it is absent
  -- @return xs, ys, ls, ns  arrays of variants for x, y, lowres_y and x_noise
  function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
    local xs, ys, ls, ns = {}, {}, {}, {}

    -- Append img, vflip(img), hflip(img), hflip(vflip(img)) to `list`.
    local function push_flips(list, img)
      local flipped = iproc.vflip(img)
      list[#list + 1] = img
      list[#list + 1] = flipped
      list[#list + 1] = iproc.hflip(img)
      list[#list + 1] = iproc.hflip(flipped)
    end

    for pass = 1, 2 do
      -- TTA
      local cur_x, cur_y, cur_l, cur_n
      if pass == 1 then
        cur_x, cur_y, cur_l, cur_n = x, y, lowres_y, x_noise
      else
        cur_x = x:transpose(2, 3):contiguous()
        cur_y = y:transpose(2, 3):contiguous()
        if lowres_y then
          cur_l = lowres_y:transpose(2, 3):contiguous()
        end
        if x_noise then
          cur_n = x_noise:transpose(2, 3):contiguous()
        end
      end

      push_flips(xs, cur_x)
      if cur_n then
        push_flips(ns, cur_n)
      end
      push_flips(ys, cur_y)
      if cur_l then
        push_flips(ls, cur_l)
      end
    end

    return xs, ys, ls, ns
  end
  --- Build a small CUDA network: SpatialAveragePooling(2, 2, 2, 2) followed
  -- by SpatialUpSamplingNearest(2).
  -- @return the nn.Sequential container after conversion with :cuda()
  local function lowres_model()
    local net = nn.Sequential()
      :add(nn.SpatialAveragePooling(2, 2, 2, 2))
      :add(nn.SpatialUpSamplingNearest(2))
    return net:cuda()
  end
  -- State for the disabled GPU path below; unused by the active code.
  local g_lowres_model = nil
  local g_lowres_gpu = nil

  --- Produce a low-resolution rendition of `src` at the original size.
  -- The image is resized to half of size(3) x size(2) with the "Box" filter,
  -- resized back to the original size with the same filter, and returned as a
  -- byte tensor in "DHW" layout.
  -- @param src  image tensor in "DHW" layout; a first dimension of 1 is handled
  --             as colorspace "I", anything else as "RGB"
  -- @return byte tensor from graphicsmagick's toTensor
  function pairwise_transform_utils.low_resolution(src)
    --[[
    -- Disabled: benchmark-selected GPU/CPU implementation.
    -- I am not sure that the following process is thread-safe
    g_lowres_model = g_lowres_model or lowres_model()
    if g_lowres_gpu == nil then
      --benchmark
      local gpu_time = sys.clock()
      for i = 1, 10 do
        g_lowres_model:forward(src:cuda()):byte()
      end
      gpu_time = sys.clock() - gpu_time
      local cpu_time = sys.clock()
      for i = 1, 10 do
        gm.Image(src, "RGB", "DHW"):
          size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
          size(src:size(3), src:size(2), "Box"):
          toTensor("byte", "RGB", "DHW")
      end
      cpu_time = sys.clock() - cpu_time
      --print(gpu_time, cpu_time)
      if gpu_time < cpu_time then
        g_lowres_gpu = true
      else
        g_lowres_gpu = false
      end
    end
    if g_lowres_gpu then
      return g_lowres_model:forward(src:cuda()):byte()
    else
      return gm.Image(src, "RGB", "DHW"):
        size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
        size(src:size(3), src:size(2), "Box"):
        toTensor("byte", "RGB", "DHW")
    end
    --]]
    local colorspace
    if src:size(1) == 1 then
      colorspace = "I"
    else
      colorspace = "RGB"
    end
    local halved = gm.Image(src, colorspace, "DHW"):size(src:size(3) * 0.5, src:size(2) * 0.5, "Box")
    local restored = halved:size(src:size(3), src:size(2), "Box")
    return restored:toTensor("byte", colorspace, "DHW")
  end
  283. return pairwise_transform_utils