-- pairwise_transform_utils.lua
  1. require 'cunn'
  2. local iproc = require 'iproc'
  3. local gm = {}
  4. gm.Image = require 'graphicsmagick.Image'
  5. local data_augmentation = require 'data_augmentation'
  6. local pairwise_transform_utils = {}
  --- Randomly shrink an image to half of its spatial size.
  -- Draws one sample from `torch.uniform()`; when the sample is below `p`,
  -- one entry of `filters` is picked at random and `src` is rescaled through
  -- `iproc.scale` to half of `src:size(3)` by half of `src:size(2)`.
  -- Otherwise `src` itself is returned (no copy is made).
  -- @param src image tensor (presumably channel x height x width -- confirm)
  -- @param p threshold compared against the uniform sample
  -- @param filters sequence of filter identifiers accepted by `iproc.scale`
  -- @return the rescaled image, or `src` when no rescaling was chosen
  function pairwise_transform_utils.random_half(src, p, filters)
    if not (torch.uniform() < p) then
      return src
    end
    local chosen_filter = filters[torch.random(1, #filters)]
    local half_w = src:size(3) * 0.5
    local half_h = src:size(2) * 0.5
    return iproc.scale(src, half_w, half_h, chosen_filter)
  end
  --- Cut a random `max_size` x `max_size` window out of an oversized image.
  -- The crop only happens when BOTH `src:size(2)` and `src:size(3)` exceed
  -- `max_size`; in every other case `src` is returned untouched.
  -- @param src image tensor indexed as (1, 2, 3); dims 2 and 3 are cropped
  -- @param max_size side length of the window; must be a multiple of 4
  -- @param mod optional; when given, both offsets are rounded down to a
  --   multiple of `mod`
  -- @return the cropped window, or `src` when it was not large enough
  function pairwise_transform_utils.crop_if_large(src, max_size, mod)
    local max_attempts = 4
    if not (src:size(2) > max_size and src:size(3) > max_size) then
      return src
    end
    assert(max_size % 4 == 0)

    local patch
    local attempt = 0
    repeat
      attempt = attempt + 1
      local top = torch.random(0, src:size(2) - max_size)
      local left = torch.random(0, src:size(3) - max_size)
      if mod then
        -- snap the offsets down onto the requested grid
        top = top - (top % mod)
        left = left - (left % mod)
      end
      patch = iproc.crop(src, left, top, left + max_size, top + max_size)
      -- ignore simple background
      -- NOTE(review): a standard deviation is never negative, so with a
      -- threshold of 0 the first candidate is accepted in practice and the
      -- remaining attempts are not used.
    until patch:float():std() >= 0 or attempt >= max_attempts
    return patch
  end
  --- Apply the random training-time augmentation pipeline to one image.
  -- Two pipelines exist. When `options.data.filters` is exactly `{"Box"}`
  -- the image is cropped on an even-aligned grid, flipped, colour/overlay/
  -- unsharp augmented and finally passed through `iproc.crop_mod4`.
  -- Otherwise the image may first be halved (`random_half`), then cropped,
  -- flipped, blurred, colour/overlay/unsharp augmented and shifted by
  -- `data_augmentation.shift_1px`.
  -- @param src input image tensor
  -- @param crop_size crop size used by the caller; the crop limit here is
  --   `math.max(crop_size * 2, options.max_size)`
  -- @param options table providing `data.filters`, `max_size`,
  --   `downsampling_filters` and the `random_*` rates read below
  -- @return the augmented image
  function pairwise_transform_utils.preprocess(src, crop_size, options)
    local filters = options.data.filters
    local box_only = false
    if filters and #filters == 1 and filters[1] == "Box" then
      box_only = true
    end

    local out = src
    if box_only then
      -- crop offsets must be even in this mode (pos % 2 == 0)
      local alignment = 2
      local limit = math.max(crop_size * 2, options.max_size)
      out = pairwise_transform_utils.crop_if_large(out, limit, alignment)
      out = data_augmentation.flip(out)
      out = data_augmentation.color_noise(out, options.random_color_noise_rate)
      out = data_augmentation.overlay(out, options.random_overlay_rate)
      out = data_augmentation.unsharp_mask(out, options.random_unsharp_mask_rate)
      out = iproc.crop_mod4(out)
      return out
    end

    out = pairwise_transform_utils.random_half(out,
                                               options.random_half_rate,
                                               options.downsampling_filters)
    local limit = math.max(crop_size * 2, options.max_size)
    out = pairwise_transform_utils.crop_if_large(out, limit)
    out = data_augmentation.flip(out)
    out = data_augmentation.blur(out,
                                 options.random_blur_rate,
                                 options.random_blur_size,
                                 options.random_blur_min,
                                 options.random_blur_max)
    out = data_augmentation.color_noise(out, options.random_color_noise_rate)
    out = data_augmentation.overlay(out, options.random_overlay_rate)
    out = data_augmentation.unsharp_mask(out, options.random_unsharp_mask_rate)
    out = data_augmentation.shift_1px(out)
    return out
  end
  -- Crop matching windows from `y` (at offset xi, yi with side `size`) and
  -- from `x` (same window divided by `scale`). Returns xc, yc.
  local function crop_pair(x, y, xi, yi, size, scale)
    local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
    local xc = iproc.crop(x,
                          xi / scale, yi / scale,
                          xi / scale + size / scale, yi / scale + size / scale)
    return xc, yc
  end

  --- Crop a training pair, preferring windows that differ from `lowres_y`.
  -- One uniform sample `r` decides the mode:
  --   * `p < r`: a single random window is cropped.
  --   * otherwise: `tries` random windows are drawn and the one whose summed
  --     squared difference between `y` and `lowres_y` is largest is used.
  -- Offsets are chosen in `x` coordinates and multiplied by `scale`, so they
  -- index into `y`; the `x` crop uses the same window divided by `scale`.
  -- @param x input image; dims 2 and 3 times `scale` must equal those of `y`
  -- @param y target image
  -- @param lowres_y image compared against `y` to score candidate windows;
  --   it is cropped with the same coordinates as `y`
  -- @param size side length of the `y` crop; must be divisible by `scale`
  -- @param scale integer size ratio between `y` and `x`
  -- @param p threshold compared with the uniform sample (see above)
  -- @param tries number of candidate windows scored in the second mode
  -- @return xc, yc the crops of `x` and `y`
  function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    -- The condition must be the FIRST argument of assert; with the message
    -- first (as before) the string is always truthy and nothing is checked.
    assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
           "x:size == y:size")
    assert(size % scale == 0, "crop_size % scale == 0")

    local r = torch.uniform()
    if p < r then
      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
      return crop_pair(x, y, xi, yi, size, scale)
    end

    -- Score `tries` candidate windows in one batch.
    local xcs = torch.LongTensor(tries, y:size(1), size, size)
    local lcs = torch.LongTensor(tries, lowres_y:size(1), size, size)
    -- Row 1 holds the x offsets, row 2 the y offsets, one column per try.
    local offsets = torch.LongTensor(2, tries)
    offsets[1]:random(1, x:size(3) - (size + 1)):mul(scale)
    offsets[2]:random(1, x:size(2) - (size + 1)):mul(scale)
    for i = 1, tries do
      local xi = offsets[1][i]
      local yi = offsets[2][i]
      xcs[i]:copy(iproc.crop_nocopy(y, xi, yi, xi + size, yi + size))
      lcs[i]:copy(iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size))
    end

    -- Squared difference per element, summed per candidate.
    xcs:csub(lcs)
    xcs:cmul(xcs)
    local per_try = xcs:reshape(xcs:size(1), xcs:nElement() / xcs:size(1))
    -- topk(1, true): the single largest score and its index.
    local _, best = per_try:transpose(1, 2):sum(1):topk(1, true)
    local best_i = best[1][1]

    return crop_pair(x, y, offsets[1][best_i], offsets[2][best_i], size, scale)
  end
  -- Append `img` and its three flipped variants to `list`, in the order:
  -- img, vflip(img), hflip(img), hflip(vflip(img)).
  local function append_flips(list, img)
    local flipped = iproc.vflip(img)
    list[#list + 1] = img
    list[#list + 1] = flipped
    list[#list + 1] = iproc.hflip(img)
    list[#list + 1] = iproc.hflip(flipped)
  end

  --- Expand a sample into its 8 flip/transpose variants.
  -- For the original images and for their dim-2/dim-3 transposes, four
  -- variants are produced each (identity, vertical flip, horizontal flip,
  -- both), giving 8 entries per list. All lists share the same ordering, so
  -- entry k of each list belongs to the same variant.
  -- @param x input image tensor
  -- @param y target image tensor
  -- @param lowres_y low-resolution counterpart of `y`
  -- @param x_noise optional extra image; when nil/false `ns` stays empty
  -- @return xs, ys, ls, ns lists of variants of x, y, lowres_y and x_noise
  function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
    local xs, ns, ys, ls = {}, {}, {}, {}
    for j = 1, 2 do
      -- TTA
      -- `ni` is declared here on purpose: it used to be assigned without
      -- `local`, which leaked it into the global environment.
      local xi, ni, yi, ri
      if j == 1 then
        xi, ni, yi, ri = x, x_noise, y, lowres_y
      else
        xi = x:transpose(2, 3):contiguous()
        if x_noise then
          ni = x_noise:transpose(2, 3):contiguous()
        end
        yi = y:transpose(2, 3):contiguous()
        ri = lowres_y:transpose(2, 3):contiguous()
      end

      append_flips(xs, xi)
      if ni then
        append_flips(ns, ni)
      end
      append_flips(ys, yi)
      append_flips(ls, ri)
    end
    return xs, ys, ls, ns
  end
  --- Build the CUDA network used to approximate a low-resolution image.
  -- The network is a 2x2 average pooling with stride 2 followed by a 2x
  -- nearest-neighbour upsampling, converted with `:cuda()`.
  -- Only the disabled benchmark code in `low_resolution` refers to it.
  -- @return the `nn.Sequential` container after `:cuda()`
  local function lowres_model()
    return nn.Sequential()
      :add(nn.SpatialAveragePooling(2, 2, 2, 2))
      :add(nn.SpatialUpSamplingNearest(2))
      :cuda()
  end
  -- Lazily created state for the (currently disabled) GPU path of
  -- `low_resolution`; both start out as nil.
  local g_lowres_model -- cached network built by lowres_model()
  local g_lowres_gpu   -- benchmark verdict: true = GPU, false = CPU
  --- Produce a blurred, same-size "low resolution" version of an image.
  -- The image is wrapped in a GraphicsMagick image ("RGB", "DHW" layout),
  -- resized to half of `src:size(3)` x `src:size(2)` with the "Box" filter,
  -- resized back to the original size with the same filter, and converted
  -- to a byte tensor.
  -- @param src image tensor in "DHW" layout with RGB channels
  -- @return byte tensor ("RGB", "DHW") of the same spatial size as `src`
  function pairwise_transform_utils.low_resolution(src)
    --[[
    -- Disabled: benchmark-driven choice between a GPU and a CPU path.
    -- I am not sure that the following process is thread-safe
    g_lowres_model = g_lowres_model or lowres_model()
    if g_lowres_gpu == nil then
      --benchmark
      local gpu_time = sys.clock()
      for i = 1, 10 do
        g_lowres_model:forward(src:cuda()):byte()
      end
      gpu_time = sys.clock() - gpu_time
      local cpu_time = sys.clock()
      for i = 1, 10 do
        gm.Image(src, "RGB", "DHW"):
          size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
          size(src:size(3), src:size(2), "Box"):
          toTensor("byte", "RGB", "DHW")
      end
      cpu_time = sys.clock() - cpu_time
      --print(gpu_time, cpu_time)
      if gpu_time < cpu_time then
        g_lowres_gpu = true
      else
        g_lowres_gpu = false
      end
    end
    if g_lowres_gpu then
      return g_lowres_model:forward(src:cuda()):byte()
    else
      return gm.Image(src, "RGB", "DHW"):
        size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
        size(src:size(3), src:size(2), "Box"):
        toTensor("byte", "RGB", "DHW")
    end
    --]]
    local img = gm.Image(src, "RGB", "DHW")
    -- shrink to half size ...
    img = img:size(src:size(3) * 0.5, src:size(2) * 0.5, "Box")
    -- ... then blow back up to the original size
    img = img:size(src:size(3), src:size(2), "Box")
    return img:toTensor("byte", "RGB", "DHW")
  end
  212. return pairwise_transform_utils