pairwise_transform_utils.lua 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. require 'image'
  2. local iproc = require 'iproc'
  3. local data_augmentation = require 'data_augmentation'
  4. local pairwise_transform_utils = {}
  --- Randomly shrink an image to half size.
  -- Draws one uniform random number; when it is below `p`, a filter name is
  -- picked uniformly from `filters` and `src` is rescaled with `iproc.scale`
  -- to half of `src:size(3)` by half of `src:size(2)`. Otherwise `src` is
  -- returned untouched (same object, no copy).
  -- @param src     image tensor (presumably channels x height x width -- confirm against iproc.scale)
  -- @param p       threshold compared against torch.uniform(); acts as the probability of halving
  -- @param filters array of filter identifiers accepted by iproc.scale; must be non-empty when halving occurs
  -- @return the rescaled image, or `src` itself
  function pairwise_transform_utils.random_half(src, p, filters)
    -- Guard clause: the common "leave it alone" case exits first.
    if not (torch.uniform() < p) then
      return src
    end
    local chosen_filter = filters[torch.random(1, #filters)]
    local half_of_dim3 = src:size(3) * 0.5
    local half_of_dim2 = src:size(2) * 0.5
    return iproc.scale(src, half_of_dim3, half_of_dim2, chosen_filter)
  end
  --- Crop a random max_size x max_size window out of an oversized image.
  -- When both `src:size(2)` and `src:size(3)` exceed `max_size`, up to four
  -- random windows are tried and the first one whose pixel values are not all
  -- identical (standard deviation > 0) is returned. If every attempt is flat,
  -- the last attempted window is returned. Images that are not larger than
  -- `max_size` in both dimensions are returned unchanged.
  -- @param src      image tensor (dimensions 2 and 3 are treated as the spatial axes)
  -- @param max_size side length of the crop; must be a multiple of 4
  -- @param mod      optional; when given, the crop origin is rounded down to a multiple of `mod`
  -- @return the cropped window, or `src` itself when no crop is needed
  function pairwise_transform_utils.crop_if_large(src, max_size, mod)
    local tries = 4
    if src:size(2) > max_size and src:size(3) > max_size then
      assert(max_size % 4 == 0, "max_size must be a multiple of 4")
      local rect
      for i = 1, tries do
        local yi = torch.random(0, src:size(2) - max_size)
        local xi = torch.random(0, src:size(3) - max_size)
        if mod then
          -- Snap the origin down so it stays aligned to `mod`.
          yi = yi - (yi % mod)
          xi = xi - (xi % mod)
        end
        rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
        -- Ignore simple (flat) backgrounds. A standard deviation is never
        -- negative, so the previous `>= 0` test accepted every crop and the
        -- retry loop never ran; require a strictly positive deviation.
        if rect:float():std() > 0 then
          break
        end
      end
      return rect
    else
      return src
    end
  end
  --- Apply the random training-time augmentation chain to one source image.
  -- Two variants exist, selected by `options.data.filters`:
  --   * "Box only" (the filter list is exactly {"Box"}): no random halving, the
  --     large-image crop origin is aligned to multiples of 2, and the result is
  --     finished with `iproc.crop_mod4`.
  --   * Otherwise: the image may first be halved (`random_half`), the crop is
  --     unaligned, and the result is finished with `data_augmentation.shift_1px`.
  -- Both variants share the same middle steps, in this order: crop_if_large,
  -- flip, color_noise, overlay, unsharp_mask.
  -- @param src       source image tensor
  -- @param crop_size training crop size; the large-image limit is max(crop_size * 2, options.max_size)
  -- @param options   table read for: data.filters, max_size, random_half_rate,
  --                  downsampling_filters, random_color_noise_rate,
  --                  random_overlay_rate, random_unsharp_mask_rate
  -- @return the augmented image
  function pairwise_transform_utils.preprocess(src, crop_size, options)
    local filter_list = options.data.filters
    local box_only = false
    if filter_list and #filter_list == 1 and filter_list[1] == "Box" then
      box_only = true
    end

    local out = src
    local align -- stays nil unless the Box-only path needs even crop origins
    if box_only then
      align = 2 -- assert pos % 2 == 0
    else
      out = pairwise_transform_utils.random_half(out, options.random_half_rate, options.downsampling_filters)
    end

    local size_limit = math.max(crop_size * 2, options.max_size)
    out = pairwise_transform_utils.crop_if_large(out, size_limit, align)

    -- Augmentations common to both variants.
    out = data_augmentation.flip(out)
    out = data_augmentation.color_noise(out, options.random_color_noise_rate)
    out = data_augmentation.overlay(out, options.random_overlay_rate)
    out = data_augmentation.unsharp_mask(out, options.random_unsharp_mask_rate)

    -- Variant-specific final step.
    if box_only then
      out = iproc.crop_mod4(out)
    else
      out = data_augmentation.shift_1px(out)
    end
    return out
  end
  --- Cut a matching pair of training patches out of an input/target image pair.
  -- `y` must be exactly `scale` times larger than `x` in dimensions 2 and 3.
  -- A uniform random number r is drawn:
  --   * if p < r, a single random position is used;
  --   * otherwise `tries` random positions are sampled and the one with the
  --     largest sum of squared differences between `y` and `lowres_y` at that
  --     window is kept (ties go to the later sample).
  -- Positions are generated in `y` coordinates as multiples of `scale`; the
  -- `x` patch uses the same position and size divided by `scale`.
  -- @param x        input image tensor
  -- @param y        target image tensor, `scale` times the spatial size of `x`
  -- @param lowres_y tensor indexed with the same coordinates as `y`; used only for scoring
  -- @param size     side length of the `y` patch; must be divisible by `scale`
  -- @param scale    size ratio between `y` and `x`
  -- @param p        threshold compared against torch.uniform(); acts as the probability of scored selection
  -- @param tries    number of candidate positions scored; must be at least 1 when scored selection runs
  -- @return xc, yc  the patch from `x` and the corresponding patch from `y`
  function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    -- assert(v, message): the condition comes first. The earlier calls passed
    -- the message string as the condition, which is always truthy, so these
    -- checks could never fail.
    assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3), "x:size == y:size")
    assert(size % scale == 0, "crop_size % scale == 0")
    local r = torch.uniform()
    if p < r then
      -- Plain random crop.
      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
      return xc, yc
    else
      -- Scored crop: keep the candidate where y differs most from lowres_y.
      local best_se = 0.0
      local best_xi, best_yi
      -- Scratch buffer reused for every candidate's difference image.
      local m = torch.LongTensor(y:size(1), size, size)
      for i = 1, tries do
        local xi = torch.random(1, x:size(3) - (size + 1)) * scale
        local yi = torch.random(1, x:size(2) - (size + 1)) * scale
        local xc = iproc.crop_nocopy(y, xi, yi, xi + size, yi + size)
        local lc = iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size)
        m:copy(xc:long()):csub(lc:long())
        m:cmul(m)
        local se = m:sum()
        -- `>=` so the first candidate is always accepted (se is never below 0).
        if se >= best_se then
          best_xi = xi
          best_yi = yi
          best_se = se
        end
      end
      local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
      local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + size / scale, best_yi / scale + size / scale)
      return xc, yc
    end
  end
  --- Expand one sample into its eight flip/transpose variants.
  -- For the original orientation and for the transpose of dimensions 2 and 3,
  -- four versions are appended in this order: as-is, vertical flip,
  -- horizontal flip, and vertical followed by horizontal flip. The same
  -- transform is applied at the same index to every input so that the
  -- returned tables stay aligned.
  -- @param x        input image tensor
  -- @param y        target image tensor
  -- @param lowres_y low-resolution companion of `y`
  -- @param x_noise  optional extra input; when nil, the returned `ns` is empty
  -- @return xs, ys, ls, ns  tables of 8 variants of x, y, lowres_y and x_noise
  --                         (note the order: `ns` is returned last)
  function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
    local xs = {}
    local ns = {}
    local ys = {}
    local ls = {}
    for j = 1, 2 do
      -- TTA
      -- `ni` is declared here as well; it used to be assigned without a
      -- declaration and therefore leaked into the global environment.
      local xi, ni, yi, ri
      if j == 1 then
        xi = x
        ni = x_noise
        yi = y
        ri = lowres_y
      else
        -- Second pass works on the transposed images.
        xi = x:transpose(2, 3):contiguous()
        if x_noise then
          ni = x_noise:transpose(2, 3):contiguous()
        end
        yi = y:transpose(2, 3):contiguous()
        ri = lowres_y:transpose(2, 3):contiguous()
      end
      local xv = image.vflip(xi)
      local nv
      if x_noise then
        nv = image.vflip(ni)
      end
      local yv = image.vflip(yi)
      local rv = image.vflip(ri)

      -- 1) unflipped
      table.insert(xs, xi)
      if ni then
        table.insert(ns, ni)
      end
      table.insert(ys, yi)
      table.insert(ls, ri)

      -- 2) vertical flip
      table.insert(xs, xv)
      if nv then
        table.insert(ns, nv)
      end
      table.insert(ys, yv)
      table.insert(ls, rv)

      -- 3) horizontal flip
      table.insert(xs, image.hflip(xi))
      if ni then
        table.insert(ns, image.hflip(ni))
      end
      table.insert(ys, image.hflip(yi))
      table.insert(ls, image.hflip(ri))

      -- 4) vertical + horizontal flip
      table.insert(xs, image.hflip(xv))
      if nv then
        table.insert(ns, image.hflip(nv))
      end
      table.insert(ys, image.hflip(yv))
      table.insert(ls, image.hflip(rv))
    end
    return xs, ys, ls, ns
  end
  156. return pairwise_transform_utils