pairwise_transform_utils.lua 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. require 'cunn'
  2. local iproc = require 'iproc'
  3. local gm = {}
  4. gm.Image = require 'graphicsmagick.Image'
  5. local data_augmentation = require 'data_augmentation'
  6. local pairwise_transform_utils = {}
  --- With probability p, downscale src to half of its width and height.
  -- The resampling filter is drawn uniformly at random from `filters`.
  -- @param src image tensor; dim 3 is used as width and dim 2 as height
  -- @param p probability of halving
  -- @param filters list of filter names handed to iproc.scale (must be non-empty
  --        whenever halving actually happens, since torch.random(1, #filters) is used)
  -- @return the halved tensor, or src itself when no halving happens
  function pairwise_transform_utils.random_half(src, p, filters)
    local roll = torch.uniform()
    if roll < p then
      local half_w, half_h = src:size(3) * 0.5, src:size(2) * 0.5
      local chosen = filters[torch.random(1, #filters)]
      return iproc.scale(src, half_w, half_h, chosen)
    end
    return src
  end
  --- Randomly crop a max_size x max_size window out of src when it is large.
  -- The crop only happens when BOTH spatial dimensions exceed max_size; otherwise
  -- src is returned untouched. Up to 4 random windows are tried, and the first one
  -- with non-zero standard deviation is kept, which skips flat single-colour
  -- backgrounds. If every try is flat, the last window is returned.
  -- @param src image tensor; dim 2 is treated as height (y) and dim 3 as width (x)
  -- @param max_size edge length of the crop; must be a multiple of 4
  -- @param mod optional; when given, the crop origin is rounded down to a multiple of mod
  -- @return the cropped tensor, or src itself
  function pairwise_transform_utils.crop_if_large(src, max_size, mod)
    local tries = 4
    if not (src:size(2) > max_size and src:size(3) > max_size) then
      return src
    end
    assert(max_size % 4 == 0)
    local rect
    for _ = 1, tries do
      local yi = torch.random(0, src:size(2) - max_size)
      local xi = torch.random(0, src:size(3) - max_size)
      if mod then
        yi = yi - (yi % mod)
        xi = xi - (xi % mod)
      end
      rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
      -- Ignore simple (flat) backgrounds. A standard deviation is never negative,
      -- so the test must be strict for the retry loop to have any effect.
      if rect:float():std() > 0 then
        break
      end
    end
    return rect
  end
  --- Randomly crop an aligned (x, y) window pair when y is large.
  -- y is cropped to max_size x max_size, and x is cropped to the same region with
  -- every coordinate divided by scale_y. The crop only happens when BOTH spatial
  -- dimensions of y exceed max_size; otherwise the inputs are returned as-is.
  -- Up to 4 windows are tried, and the first one whose y crop has non-zero standard
  -- deviation is kept, which skips flat backgrounds. If every try is flat, the last
  -- window is returned.
  -- @param x lower-resolution input tensor (dim 2 = height, dim 3 = width)
  -- @param y target tensor, scale_y times larger than x
  -- @param scale_y size ratio y / x
  -- @param max_size crop edge length in y coordinates; must be a multiple of 4
  -- @param mod optional; when given, the crop origin in y is rounded down to a
  --        multiple of mod (callers pass scale_y so x coordinates stay integral)
  -- @return rect_x, rect_y (or x, y unchanged)
  function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
    local tries = 4
    if not (y:size(2) > max_size and y:size(3) > max_size) then
      return x, y
    end
    assert(max_size % 4 == 0)
    local rect_x, rect_y
    for _ = 1, tries do
      local yi = torch.random(0, y:size(2) - max_size)
      local xi = torch.random(0, y:size(3) - max_size)
      if mod then
        yi = yi - (yi % mod)
        xi = xi - (xi % mod)
      end
      rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
      rect_x = iproc.crop(x, xi / scale_y, yi / scale_y,
                          xi / scale_y + max_size / scale_y,
                          yi / scale_y + max_size / scale_y)
      -- Ignore simple (flat) backgrounds. A standard deviation is never negative,
      -- so the test must be strict for the retry loop to have any effect.
      if rect_y:float():std() > 0 then
        break
      end
    end
    return rect_x, rect_y
  end
  --- Apply the random single-image augmentations used before a training pair is built.
  -- The pipeline is chosen by options.data.filters:
  --   * exactly {"Box"}: crop_if_large with the origin on an even grid (mod 2),
  --     then color noise, overlay, unsharp mask, and finally iproc.crop_mod4;
  --   * anything else: random_half, crop_if_large, blur, color noise, overlay,
  --     unsharp mask, and finally data_augmentation.shift_1px.
  -- Large images are cropped to max(crop_size * 2, options.max_size).
  -- @param src source image tensor
  -- @param crop_size crop size; large images are cropped to at least twice this
  -- @param options options table (reads options.data.filters, options.max_size and
  --        the random_* / downsampling_filters fields)
  -- @return the augmented tensor
  function pairwise_transform_utils.preprocess(src, crop_size, options)
    local filters = options.data.filters
    local box_only = filters and #filters == 1 and filters[1] == "Box"

    -- Augmentations shared by both pipelines, always applied in this order.
    local function noise_overlay_sharpen(img)
      img = data_augmentation.color_noise(img, options.random_color_noise_rate)
      img = data_augmentation.overlay(img, options.random_overlay_rate)
      return data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
    end

    if box_only then
      -- mod = 2 keeps the crop origin on even coordinates
      local img = pairwise_transform_utils.crop_if_large(
        src, math.max(crop_size * 2, options.max_size), 2)
      return iproc.crop_mod4(noise_overlay_sharpen(img))
    end

    local img = pairwise_transform_utils.random_half(
      src, options.random_half_rate, options.downsampling_filters)
    img = pairwise_transform_utils.crop_if_large(
      img, math.max(crop_size * 2, options.max_size))
    img = data_augmentation.blur(img, options.random_blur_rate,
                                 options.random_blur_size,
                                 options.random_blur_sigma_min,
                                 options.random_blur_sigma_max)
    return data_augmentation.shift_1px(noise_overlay_sharpen(img))
  end
  --- Jointly augment a user-supplied (x, y) training pair.
  -- Steps, in order:
  --   1. crop_if_large_pair, with the origin aligned to scale_y;
  --   2. pairwise rotation;
  --   3. pairwise scaling;
  --   4. pairwise negation, then x-only negation;
  --   5. iproc.crop_mod4 on each side;
  --   6. optionally binarize y.
  -- @param x input tensor (lower resolution when scale_y > 1)
  -- @param y target tensor
  -- @param scale_y size ratio y / x
  -- @param size patch size; used to raise the minimum random scale factor
  -- @param options options table (max_size, random_pairwise_* fields, pairwise_y_binary)
  -- @return x, y
  function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
    local o = options
    x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, o.max_size, scale_y)
    x, y = data_augmentation.pairwise_rotate(x, y,
                                             o.random_pairwise_rotate_rate,
                                             o.random_pairwise_rotate_min,
                                             o.random_pairwise_rotate_max)

    -- The scale factor lower bound is raised to at least size / (1 + shorter side of x).
    -- The upper bound is never allowed below that lower bound.
    local short_side = math.min(x:size(2), x:size(3))
    local scale_lo = math.max(o.random_pairwise_scale_min, size / (1 + short_side))
    local scale_hi = math.max(scale_lo, o.random_pairwise_scale_max)
    x, y = data_augmentation.pairwise_scale(x, y, o.random_pairwise_scale_rate, scale_lo, scale_hi)

    x, y = data_augmentation.pairwise_negate(x, y, o.random_pairwise_negate_rate)
    x, y = data_augmentation.pairwise_negate_x(x, y, o.random_pairwise_negate_x_rate)
    x = iproc.crop_mod4(x)
    y = iproc.crop_mod4(y)

    if o.pairwise_y_binary then
      -- Binarize y: values below 128 become 0, every remaining value becomes 255.
      y[torch.lt(y, 128)] = 0
      y[torch.gt(y, 0)] = 255
    end
    return x, y
  end
  -- Crop the size x size window of y whose top-left corner is (xi, yi).
  -- Also crop the window of x covering the same area, with coordinates divided by scale.
  local function crop_pair_at(x, y, xi, yi, size, scale)
    local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
    local xc = iproc.crop(x, xi / scale, yi / scale,
                          xi / scale + size / scale, yi / scale + size / scale)
    return xc, yc
  end

  --- Choose an aligned (x, y) crop pair for training.
  -- With probability (1 - p), a uniformly random window is used. Otherwise `tries`
  -- candidate windows are sampled, and the one with the largest sum of squared
  -- differences between y and lowres_y is kept. That is the region the
  -- low-resolution version reproduces worst.
  -- @param x input tensor; its spatial dims multiplied by scale must equal y's
  -- @param y target tensor
  -- @param lowres_y tensor cropped with y's coordinates (presumably the same size
  --        as y; TODO confirm with callers)
  -- @param size crop edge length in y coordinates; must be divisible by scale
  -- @param scale size ratio between y and x
  -- @param p probability of the error-guided ("active") selection
  -- @param tries number of candidate windows for the active selection
  -- @return xc, yc
  function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    -- Lua's assert takes the condition FIRST. With the message first, the
    -- always-truthy string would turn these precondition checks into no-ops.
    assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
           "x:size == y:size")
    assert(size % scale == 0, "crop_size % scale == 0")

    if p < torch.uniform() then
      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
      return crop_pair_at(x, y, xi, yi, size, scale)
    end

    -- Gather every candidate window into LongTensor buffers. This is presumably so
    -- the subtraction below does not wrap around as it would on byte tensors.
    local ycs = torch.LongTensor(tries, y:size(1), size, size)
    local lcs = torch.LongTensor(tries, lowres_y:size(1), size, size)
    local offsets = torch.LongTensor(2, tries)
    offsets[1]:random(1, x:size(3) - (size + 1)):mul(scale)
    offsets[2]:random(1, x:size(2) - (size + 1)):mul(scale)
    for i = 1, tries do
      local xi, yi = offsets[1][i], offsets[2][i]
      ycs[i]:copy(iproc.crop_nocopy(y, xi, yi, xi + size, yi + size))
      lcs[i]:copy(iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size))
    end

    -- Compute the per-candidate sum of squared differences and pick the largest.
    ycs:csub(lcs)
    ycs:cmul(ycs)
    local _, idx = ycs:reshape(ycs:size(1), ycs:nElement() / ycs:size(1))
                      :transpose(1, 2):sum(1):topk(1, true)
    local best = idx[1][1]
    return crop_pair_at(x, y, offsets[1][best], offsets[2][best], size, scale)
  end
  --- Expand one training sample into its 8 flip/transpose variants.
  -- Two passes are made: one on the inputs as given, and one with dims 2 and 3
  -- swapped. Each pass emits four variants in this order: as-is, vflip, hflip,
  -- and hflip of the vflip. The same transform is applied to x, y, lowres_y and
  -- (when present) x_noise, so index k of every returned list refers to the same
  -- geometric transform.
  -- @param x input tensor
  -- @param y target tensor
  -- @param lowres_y low-resolution tensor transformed alongside y
  -- @param x_noise optional extra input; may be nil
  -- @return xs, ys, ls, ns: lists of 8 tensors each (ns is empty when x_noise is nil)
  function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
    local xs, ys, ls, ns = {}, {}, {}, {}

    -- Append src, vflip(src), hflip(src) and hflip(vflip(src)) to list.
    local function push_flips(list, src)
      local v = iproc.vflip(src)
      list[#list + 1] = src
      list[#list + 1] = v
      list[#list + 1] = iproc.hflip(src)
      list[#list + 1] = iproc.hflip(v)
    end

    -- Swap dims 2 and 3 and return a contiguous tensor.
    local function transposed(t)
      return t:transpose(2, 3):contiguous()
    end

    for j = 1, 2 do
      -- ni is declared local along with the others so it does not leak into
      -- the global table.
      local xi, ni, yi, ri
      if j == 1 then
        xi, ni, yi, ri = x, x_noise, y, lowres_y
      else
        xi = transposed(x)
        if x_noise then
          ni = transposed(x_noise)
        end
        yi = transposed(y)
        ri = transposed(lowres_y)
      end
      push_flips(xs, xi)
      if ni then
        push_flips(ns, ni)
      end
      push_flips(ys, yi)
      push_flips(ls, ri)
    end
    return xs, ys, ls, ns
  end
  -- Build a CUDA network that averages 2x2 blocks (stride 2) and then upsamples
  -- by a factor of 2 with nearest-neighbour interpolation.
  -- Within this file it is only referenced from the disabled GPU path in
  -- low_resolution.
  local function lowres_model()
    local pool = nn.SpatialAveragePooling(2, 2, 2, 2)
    local upsample = nn.SpatialUpSamplingNearest(2)
    return nn.Sequential():add(pool):add(upsample):cuda()
  end
  -- State for the disabled GPU path in low_resolution: the cached lowres_model()
  -- network, and whether the CPU/GPU benchmark chose the GPU. Both start as nil.
  local g_lowres_model, g_lowres_gpu
  --- Degrade src by a 2x Box-filter downscale followed by a resize back.
  -- The image is shrunk to half its width/height and then resized back to the
  -- original width/height, both times with GraphicsMagick's "Box" filter.
  -- @param src tensor accepted by gm.Image with layout "RGB"/"DHW"
  -- @return byte tensor ("RGB", "DHW") with the width/height of src
  function pairwise_transform_utils.low_resolution(src)
    --[[
    -- Disabled alternative. It benchmarks lowres_model() on the GPU against
    -- GraphicsMagick (10 iterations each) and then uses whichever was faster.
    -- The original author was not sure this is thread-safe: it lazily
    -- initializes the module-level g_lowres_model / g_lowres_gpu.
    g_lowres_model = g_lowres_model or lowres_model()
    if g_lowres_gpu == nil then
      --benchmark
      local gpu_time = sys.clock()
      for i = 1, 10 do
        g_lowres_model:forward(src:cuda()):byte()
      end
      gpu_time = sys.clock() - gpu_time
      local cpu_time = sys.clock()
      for i = 1, 10 do
        gm.Image(src, "RGB", "DHW"):
          size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
          size(src:size(3), src:size(2), "Box"):
          toTensor("byte", "RGB", "DHW")
      end
      cpu_time = sys.clock() - cpu_time
      --print(gpu_time, cpu_time)
      if gpu_time < cpu_time then
        g_lowres_gpu = true
      else
        g_lowres_gpu = false
      end
    end
    if g_lowres_gpu then
      return g_lowres_model:forward(src:cuda()):byte()
    else
      return gm.Image(src, "RGB", "DHW"):
        size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
        size(src:size(3), src:size(2), "Box"):
        toTensor("byte", "RGB", "DHW")
    end
    --]]
    local width, height = src:size(3), src:size(2)
    local img = gm.Image(src, "RGB", "DHW")
    img = img:size(width * 0.5, height * 0.5, "Box")
    img = img:size(width, height, "Box")
    return img:toTensor("byte", "RGB", "DHW")
  end
  -- Hand the populated module table back to `require`.
  return pairwise_transform_utils