pairwise_transform_utils.lua 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. require 'cunn'
  2. local iproc = require 'iproc'
  3. local gm = {}
  4. gm.Image = require 'graphicsmagick.Image'
  5. local data_augmentation = require 'data_augmentation'
  6. local pairwise_transform_utils = {}
  --- Randomly downscales an image to half its size.
  -- With probability `p`, one filter is picked uniformly from `filters` and
  -- `src` is resampled by iproc.scale to half its width (dim 3) and half its
  -- height (dim 2). Otherwise `src` is returned untouched.
  -- @param src image tensor (presumably channel x height x width — TODO confirm)
  -- @param p probability of downscaling
  -- @param filters non-empty list of filter names accepted by iproc.scale
  -- @return the half-size tensor, or `src` itself
  function pairwise_transform_utils.random_half(src, p, filters)
    if not (torch.uniform() < p) then
      return src
    end
    local chosen_filter = filters[torch.random(1, #filters)]
    local half_width = src:size(3) * 0.5
    local half_height = src:size(2) * 0.5
    return iproc.scale(src, half_width, half_height, chosen_filter)
  end
  --- Randomly crops a max_size x max_size square out of a large image.
  -- Only images larger than max_size in both height (dim 2) and width (dim 3)
  -- are cropped; anything else is returned unchanged. Up to `tries` random
  -- positions are sampled. The first crop whose pixel standard deviation is
  -- non-zero is kept, so flat single-colour background patches are skipped
  -- when possible. If every sample is flat, the last one is returned.
  -- @param src image tensor (presumably channel x height x width — TODO confirm)
  -- @param max_size side length of the square crop; must be a multiple of 4
  -- @param mod optional alignment; crop offsets are rounded down to a multiple of it
  -- @return the cropped tensor, or `src` unchanged
  function pairwise_transform_utils.crop_if_large(src, max_size, mod)
    local tries = 4
    local height, width = src:size(2), src:size(3)
    if not (height > max_size and width > max_size) then
      return src
    end
    assert(max_size % 4 == 0)
    local rect
    for _ = 1, tries do
      local yi = torch.random(0, height - max_size)
      local xi = torch.random(0, width - max_size)
      if mod then
        yi = yi - yi % mod
        xi = xi - xi % mod
      end
      rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
      -- Ignore simple (flat) backgrounds. std() is never negative, so the test
      -- must be strict or every crop would be accepted and no retry would happen.
      if rect:float():std() > 0 then
        break
      end
    end
    return rect
  end
  --- Randomly crops matching regions from an (x, y) pair when y is large.
  -- y is cropped to max_size x max_size. x is cropped at the same position
  -- with all coordinates divided by scale_y, i.e. x is taken to be 1/scale_y
  -- the size of y. Up to `tries` positions are sampled. The first whose y
  -- crop has a non-zero standard deviation is kept, so flat backgrounds are
  -- skipped when possible; otherwise the last sample is returned.
  -- @param x input tensor, scale_y times smaller than y
  -- @param y target tensor (presumably channel x height x width — TODO confirm)
  -- @param scale_y size ratio between y and x
  -- @param max_size crop side length in y pixels; must be a multiple of 4
  -- @param mod optional alignment for y offsets; passing scale_y keeps x offsets integral
  -- @return rect_x, rect_y, or x, y unchanged when y is not larger than
  --         max_size in both height and width
  function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
    local tries = 4
    local height, width = y:size(2), y:size(3)
    if not (height > max_size and width > max_size) then
      return x, y
    end
    assert(max_size % 4 == 0)
    local x_side = max_size / scale_y
    local rect_x, rect_y
    for _ = 1, tries do
      local yi = torch.random(0, height - max_size)
      local xi = torch.random(0, width - max_size)
      if mod then
        yi = yi - yi % mod
        xi = xi - xi % mod
      end
      rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
      local x0, y0 = xi / scale_y, yi / scale_y
      rect_x = iproc.crop(x, x0, y0, x0 + x_side, y0 + x_side)
      -- Ignore simple (flat) backgrounds. std() is never negative, so the test
      -- must be strict or every crop would be accepted and no retry would happen.
      if rect_y:float():std() > 0 then
        break
      end
    end
    return rect_x, rect_y
  end
  --- Runs the single-image augmentation pipeline used to build training data.
  -- If options.data.filters is exactly { "Box" }, the image is cropped at
  -- even offsets and finally passed through iproc.crop_mod4. Otherwise it is
  -- randomly halved, cropped, blurred and finally passed through
  -- data_augmentation.shift_1px. Both paths apply colour noise, overlay and
  -- unsharp-mask augmentation in between. The steps always run in the same
  -- order, so the random number stream is consumed in a fixed sequence.
  -- NOTE(review): the even / mod-4 alignment presumably keeps a later 2x Box
  -- downscale pixel-aligned; confirm.
  -- @param src input image tensor
  -- @param crop_size training patch size; images are only cropped when larger
  --        than max(2 * crop_size, options.max_size)
  -- @param options augmentation options (rates, blur settings, max_size, data.filters, ...)
  -- @return the augmented image tensor
  function pairwise_transform_utils.preprocess(src, crop_size, options)
    local filters = options.data.filters
    local box_only = not not (filters and #filters == 1 and filters[1] == "Box")
    local dest
    if box_only then
      -- crop offsets are forced even (assert pos % 2 == 0)
      dest = pairwise_transform_utils.crop_if_large(src, math.max(crop_size * 2, options.max_size), 2)
    else
      dest = pairwise_transform_utils.random_half(src, options.random_half_rate, options.downsampling_filters)
      dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
      dest = data_augmentation.blur(dest,
        options.random_blur_rate,
        options.random_blur_size,
        options.random_blur_sigma_min,
        options.random_blur_sigma_max)
    end
    dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
    dest = data_augmentation.overlay(dest, options.random_overlay_rate)
    dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
    if box_only then
      return iproc.crop_mod4(dest)
    end
    return data_augmentation.shift_1px(dest)
  end
  --- Augments a user-supplied (input, target) training pair.
  -- Crops large pairs, then applies random paired rotation, scaling and
  -- negation. Both images are then trimmed so that x is a multiple of 4 in
  -- each dimension, and y optionally becomes binary (0 / 255).
  -- @param x input image tensor
  -- @param y target image tensor; expected to be scale_y times the size of x
  --        (crop_if_large_pair and active_cropping rely on that ratio)
  -- @param scale_y size ratio between y and x
  -- @param size training patch size; bounds the smallest random scale factor
  -- @param options augmentation options table
  -- @return x, y
  function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
    -- Aligning y crop offsets to scale_y keeps the matching x offsets integral.
    x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
    x, y = data_augmentation.pairwise_rotate(x, y,
      options.random_pairwise_rotate_rate,
      options.random_pairwise_rotate_min,
      options.random_pairwise_rotate_max)
    -- Lower bound on the random scale: size / (1 + shorter side of x),
    -- presumably so the scaled image can still hold a `size` patch.
    local scale_min = math.max(options.random_pairwise_scale_min, size / (1 + math.min(x:size(2), x:size(3))))
    local scale_max = math.max(scale_min, options.random_pairwise_scale_max)
    x, y = data_augmentation.pairwise_scale(x, y,
      options.random_pairwise_scale_rate,
      scale_min,
      scale_max)
    x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
    x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
    -- Trim x to a multiple of 4 from the top-left corner, then crop y to
    -- exactly scale_y times that region (never beyond y's own multiple-of-4
    -- size). Trimming x and y independently breaks the size ratio: with
    -- scale_y = 2 and an x width of 6, x would become 4 wide while y stayed 12.
    local x_w = x:size(3) - x:size(3) % 4
    local x_h = x:size(2) - x:size(2) % 4
    x = iproc.crop(x, 0, 0, x_w, x_h)
    local y_w = math.min(y:size(3), x_w * scale_y)
    local y_h = math.min(y:size(2), x_h * scale_y)
    y = iproc.crop(y, 0, 0, y_w - y_w % 4, y_h - y_h % 4)
    if options.pairwise_y_binary then
      -- Values below 128 become 0; every remaining non-zero value becomes 255.
      y[torch.lt(y, 128)] = 0
      y[torch.gt(y, 0)] = 255
    end
    return x, y
  end
  -- Crops the size x size window at (xi, yi) of y, and the window of x at the
  -- same position with all coordinates divided by scale.
  local function crop_pair_at(x, y, xi, yi, size, scale)
    local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
    local x0, y0 = xi / scale, yi / scale
    local xc = iproc.crop(x, x0, y0, x0 + size / scale, y0 + size / scale)
    return xc, yc
  end

  --- Picks matching patches from an (x, y) training pair.
  -- With probability 1 - p a uniformly random window is used. Otherwise
  -- `tries` random windows are scored by the sum of squared differences
  -- between y and lowres_y, and the window with the largest score is kept.
  -- @param x input tensor (presumably channel x height x width — TODO confirm)
  -- @param y target tensor; must be exactly `scale` times x in height and width
  -- @param lowres_y tensor cropped at the same coordinates as y; used only for scoring
  -- @param size patch side length in y pixels; must be divisible by scale
  -- @param scale size ratio between y and x
  -- @param p probability of using the scored ("active") selection
  -- @param tries number of candidate windows for the scored selection
  -- @return xc, yc crops of x and y covering the same region
  function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    -- assert(condition, message): the condition comes first. With the message
    -- first, the (truthy) string was tested and the check could never fail.
    assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3), "x:size == y:size")
    assert(size % scale == 0, "crop_size % scale == 0")
    local u = torch.uniform()
    if p < u then
      -- Offsets are drawn in x pixels and converted to y pixels. xi runs
      -- along the width (dim 3) and yi along the height (dim 2), matching
      -- iproc.crop's (x1, y1, x2, y2) argument order.
      local xi, yi = 0, 0
      if x:size(3) > size + 1 then
        xi = torch.random(0, x:size(3) - (size + 1)) * scale
      end
      if x:size(2) > size + 1 then
        yi = torch.random(0, x:size(2) - (size + 1)) * scale
      end
      return crop_pair_at(x, y, xi, yi, size, scale)
    end
    -- Scored selection: compare each candidate window of y with lowres_y.
    local diff = torch.LongTensor(tries, y:size(1), size, size)
    local base = torch.LongTensor(tries, lowres_y:size(1), size, size)
    local offsets = torch.LongTensor(2, tries)
    offsets[1]:random(1, x:size(3) - (size + 1)):mul(scale) -- xi (width)
    offsets[2]:random(1, x:size(2) - (size + 1)):mul(scale) -- yi (height)
    for i = 1, tries do
      local xi, yi = offsets[1][i], offsets[2][i]
      diff[i]:copy(iproc.crop_nocopy(y, xi, yi, xi + size, yi + size))
      base[i]:copy(iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size))
    end
    diff:csub(base)
    diff:cmul(diff)
    -- (tries, N) -> (N, tries) -> per-window sums (1, tries) -> index of the largest.
    local _, best = diff:reshape(diff:size(1), diff:nElement() / diff:size(1)):transpose(1, 2):sum(1):topk(1, true)
    local k = best[1][1]
    return crop_pair_at(x, y, offsets[1][k], offsets[2][k], size, scale)
  end
  --- Expands a training sample into its 8 flip/transpose variants.
  -- The sample is processed as given, and again with height and width
  -- swapped (transpose of dims 2 and 3). In each pass every tensor
  -- contributes four entries, in this order: itself, vflip, hflip, and
  -- hflip of the vflip.
  -- lowres_y and x_noise are optional. When they are nil/false, their
  -- lists stay empty.
  -- @return xs, ys, ls, ns: xs and ys hold 8 tensors each; ls and ns hold 8 or 0
  function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
    local xs, ys, ls, ns = {}, {}, {}, {}
    -- Appends t, vflip(t), hflip(t), hflip(vflip(t)) to list.
    local function add_flips(list, t)
      local flipped = iproc.vflip(t)
      list[#list + 1] = t
      list[#list + 1] = flipped
      list[#list + 1] = iproc.hflip(t)
      list[#list + 1] = iproc.hflip(flipped)
    end
    -- Contiguous copy of t with dims 2 and 3 swapped.
    local function swap_hw(t)
      return t:transpose(2, 3):contiguous()
    end
    for pass = 1, 2 do
      local cur_x, cur_y, cur_low, cur_noise = x, y, lowres_y, x_noise
      if pass == 2 then
        cur_x = swap_hw(x)
        cur_y = swap_hw(y)
        cur_low = lowres_y and swap_hw(lowres_y) or nil
        cur_noise = x_noise and swap_hw(x_noise) or nil
      end
      add_flips(xs, cur_x)
      add_flips(ys, cur_y)
      if cur_low then
        add_flips(ls, cur_low)
      end
      if cur_noise then
        add_flips(ns, cur_noise)
      end
    end
    return xs, ys, ls, ns
  end
  -- Builds a CUDA network that averages 2x2 blocks (stride 2) and then
  -- upsamples by 2 with nearest-neighbour. In this file it is referenced only
  -- from the commented-out GPU path inside low_resolution.
  local function lowres_model()
    return nn.Sequential()
      :add(nn.SpatialAveragePooling(2, 2, 2, 2))
      :add(nn.SpatialUpSamplingNearest(2))
      :cuda()
  end
  -- Lazily-initialised state for the disabled GPU path in low_resolution;
  -- the active code below does not read it.
  local g_lowres_model = nil
  local g_lowres_gpu = nil

  --- Degrades an image by halving it and scaling it back up, using GraphicsMagick.
  -- The image is resized to half its width/height with the "Box" filter, then
  -- resized back to the original width/height with "Box" again.
  -- @param src image tensor handed to GraphicsMagick as "RGB" in "DHW" layout
  -- @return byte tensor ("RGB", "DHW") with the same width and height as src
  function pairwise_transform_utils.low_resolution(src)
    --[[
    -- Disabled alternative: on first use, benchmark the CUDA model from
    -- lowres_model() against GraphicsMagick and keep whichever is faster.
    -- I am not sure that the following process is thread-safe
    g_lowres_model = g_lowres_model or lowres_model()
    if g_lowres_gpu == nil then
      --benchmark
      local gpu_time = sys.clock()
      for i = 1, 10 do
        g_lowres_model:forward(src:cuda()):byte()
      end
      gpu_time = sys.clock() - gpu_time
      local cpu_time = sys.clock()
      for i = 1, 10 do
        gm.Image(src, "RGB", "DHW"):
          size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
          size(src:size(3), src:size(2), "Box"):
          toTensor("byte", "RGB", "DHW")
      end
      cpu_time = sys.clock() - cpu_time
      --print(gpu_time, cpu_time)
      if gpu_time < cpu_time then
        g_lowres_gpu = true
      else
        g_lowres_gpu = false
      end
    end
    if g_lowres_gpu then
      return g_lowres_model:forward(src:cuda()):byte()
    else
      return gm.Image(src, "RGB", "DHW"):
        size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
        size(src:size(3), src:size(2), "Box"):
        toTensor("byte", "RGB", "DHW")
    end
    --]]
    local width, height = src:size(3), src:size(2)
    local image = gm.Image(src, "RGB", "DHW")
    local halved = image:size(width * 0.5, height * 0.5, "Box")
    local restored = halved:size(width, height, "Box")
    return restored:toTensor("byte", "RGB", "DHW")
  end
  275. return pairwise_transform_utils