pairwise_transform_utils.lua 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. require 'image'
  2. local iproc = require 'iproc'
  3. local data_augmentation = require 'data_augmentation'
  4. local pairwise_transform_utils = {}
  -- Randomly downscale `src` to half of its size.
  --
  -- A torch.uniform() sample is compared against `p`. When the sample is below
  -- `p`, one entry of `filters` is picked uniformly at random and passed to
  -- iproc.scale together with half of src:size(3) and half of src:size(2).
  -- Otherwise `src` itself is returned unchanged.
  --
  -- NOTE(review): presumably `src` is a CxHxW image tensor, so size(3)/size(2)
  -- are width/height -- confirm against iproc.scale.
  function pairwise_transform_utils.random_half(src, p, filters)
    if not (torch.uniform() < p) then
      return src
    end
    local chosen_filter = filters[torch.random(1, #filters)]
    local half_w = src:size(3) * 0.5
    local half_h = src:size(2) * 0.5
    return iproc.scale(src, half_w, half_h, chosen_filter)
  end
  -- Return a random max_size x max_size crop of `src` when `src` is larger
  -- than `max_size` along both dimension 2 and dimension 3; otherwise return
  -- `src` unchanged.
  --
  -- Parameters:
  --   src      : tensor indexed via size(2)/size(3)
  --              (NOTE(review): presumably a CxHxW image -- confirm).
  --   max_size : edge length of the crop; must be a multiple of 4 when a crop
  --              is actually taken (asserted below).
  --   mod      : optional; when given, the crop offsets are rounded down to a
  --              multiple of `mod`.
  --
  -- Up to `tries` candidate crops are drawn. A candidate whose standard
  -- deviation is exactly 0 (a completely flat region) is rejected and a new
  -- one is drawn; if every attempt is flat, the last candidate is returned.
  function pairwise_transform_utils.crop_if_large(src, max_size, mod)
    local tries = 4
    if not (src:size(2) > max_size and src:size(3) > max_size) then
      return src
    end
    assert(max_size % 4 == 0)
    local rect
    for _ = 1, tries do
      local yi = torch.random(0, src:size(2) - max_size)
      local xi = torch.random(0, src:size(3) - max_size)
      if mod then
        -- snap the offsets down to a multiple of `mod`
        yi = yi - (yi % mod)
        xi = xi - (xi % mod)
      end
      rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
      -- Ignore simple background: accept only crops with some variation.
      -- The previous test (`>= 0`) held for every crop, because a standard
      -- deviation is never negative, so the retries could never happen.
      if rect:float():std() > 0 then
        break
      end
    end
    return rect
  end
  -- Run the augmentation pipeline on `src` and return the result.
  --
  -- Two variants exist, selected by `options.data.filters`:
  --   * "Box only" (the filter list is exactly { "Box" }): crop offsets are
  --     aligned to a multiple of 2, and neither the random half-downscale nor
  --     the 1px shift is applied.
  --   * Otherwise: random half-downscale first, unaligned crop, and a final
  --     data_augmentation.shift_1px step.
  -- Both variants share crop_if_large -> flip -> color_noise -> overlay ->
  -- unsharp_mask, in that order. The crop limit is the larger of
  -- `crop_size * 2` and `options.max_size`.
  function pairwise_transform_utils.preprocess(src, crop_size, options)
    local filter_names = options.data.filters
    local box_only = false
    if filter_names and #filter_names == 1 and filter_names[1] == "Box" then
      box_only = true
    end

    local dest = src
    if not box_only then
      dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
    end

    -- assert pos % 2 == 0 (Box-only path): align crop offsets to even positions
    local mod = nil
    if box_only then
      mod = 2
    end
    local crop_limit = math.max(crop_size * 2, options.max_size)
    dest = pairwise_transform_utils.crop_if_large(dest, crop_limit, mod)

    dest = data_augmentation.flip(dest)
    dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
    dest = data_augmentation.overlay(dest, options.random_overlay_rate)
    dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)

    if not box_only then
      dest = data_augmentation.shift_1px(dest)
    end
    return dest
  end
  -- Cut a matching pair of patches out of `x` and `y`.
  --
  -- Parameters:
  --   x        : input tensor; y:size(2) and y:size(3) must equal
  --              x:size(2) * scale and x:size(3) * scale (asserted).
  --   y        : target tensor; the returned `yc` is a size x size crop of it.
  --   lowres_y : tensor cropped with the same coordinates as `y` during the
  --              search (NOTE(review): presumably a degraded version of `y`
  --              at the same resolution -- confirm against callers).
  --   size     : crop edge length in `y` coordinates; must be divisible by
  --              `scale` (asserted).
  --   scale    : ratio between `y` and `x` coordinates.
  --   p        : a torch.uniform() sample `r` is drawn; the search branch is
  --              taken when `p >= r`, the plain random branch when `p < r`.
  --   tries    : number of candidate positions evaluated in the search branch.
  --
  -- Returns: xc, yc -- the crop of `x` and the corresponding crop of `y`.
  --
  -- In the search branch, the candidate with the largest sum of squared
  -- differences between the `y` crop and the `lowres_y` crop wins; ties go to
  -- the later candidate.
  function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    -- assert(v, message): the condition comes first. The original passed the
    -- message string first, which is always truthy, so the checks never ran.
    assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3), "x:size == y:size")
    assert(size % scale == 0, "crop_size % scale == 0")

    -- Crop `y` at (xi, yi) and `x` at the same position divided by `scale`.
    local function crop_pair(xi, yi)
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
      return xc, yc
    end

    local r = torch.uniform()
    if p < r then
      -- plain random position, expressed in `y` coordinates
      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
      return crop_pair(xi, yi)
    end

    local best_se = 0.0
    local best_xi, best_yi
    -- scratch buffer reused for every candidate
    local m = torch.LongTensor(y:size(1), size, size)
    for _ = 1, tries do
      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
      local yc = iproc.crop_nocopy(y, xi, yi, xi + size, yi + size)
      local lc = iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size)
      -- squared error between the target crop and the low-res crop
      m:copy(yc:long()):csub(lc:long())
      m:cmul(m)
      local se = m:sum()
      if se >= best_se then
        best_xi = xi
        best_yi = yi
        best_se = se
      end
    end
    return crop_pair(best_xi, best_yi)
  end
  -- Append four variants of `img` to `list`, in this order:
  -- img, vflip(img), hflip(img), hflip(vflip(img)).
  local function append_flip_variants(list, img)
    local flipped_v = image.vflip(img)
    table.insert(list, img)
    table.insert(list, flipped_v)
    table.insert(list, image.hflip(img))
    table.insert(list, image.hflip(flipped_v))
  end

  -- Build the flipped/transposed variants of a training tuple.
  --
  -- Parameters:
  --   x, y, lowres_y : tensors; each is used as-is and also with dimensions 2
  --                    and 3 transposed.
  --   x_noise        : optional tensor handled the same way as `x`; may be nil.
  --
  -- Returns: xs, ys, ls, ns -- lists of variants of x, y, lowres_y and x_noise.
  -- Each list holds 8 entries: for the original and then the transposed
  -- tensor, the entries are [tensor, vflip, hflip, hflip of vflip]. `ns` is
  -- empty when `x_noise` is nil.
  function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
    local xs = {}
    local ns = {}
    local ys = {}
    local ls = {}
    for j = 1, 2 do
      -- TTA
      -- `ni` is declared here; the original omitted it from the local list and
      -- wrote to a global `ni` on every call.
      local xi, ni, yi, ri
      if j == 1 then
        xi = x
        ni = x_noise
        yi = y
        ri = lowres_y
      else
        xi = x:transpose(2, 3):contiguous()
        if x_noise then
          ni = x_noise:transpose(2, 3):contiguous()
        end
        yi = y:transpose(2, 3):contiguous()
        ri = lowres_y:transpose(2, 3):contiguous()
      end
      append_flip_variants(xs, xi)
      if ni then
        append_flip_variants(ns, ni)
      end
      append_flip_variants(ys, yi)
      append_flip_variants(ls, ri)
    end
    return xs, ys, ls, ns
  end
  155. return pairwise_transform_utils