pairwise_transform_utils.lua 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. require 'cunn'
  2. local iproc = require 'iproc'
  3. local gm = {}
  4. gm.Image = require 'graphicsmagick.Image'
  5. local data_augmentation = require 'data_augmentation'
  6. local pairwise_transform_utils = {}
  --- Randomly replace an image with a half-size version of itself.
  -- A `torch.uniform()` draw is compared against `p`; when the draw is below
  -- `p`, one entry of `filters` is picked uniformly at random and `src` is
  -- rescaled through `iproc.scale` to half of its dimension-3 and dimension-2
  -- extents. Otherwise `src` is handed back untouched.
  -- @param src image tensor (assumed channel x height x width -- TODO confirm)
  -- @param p threshold for the uniform draw
  -- @param filters non-empty array of filter identifiers understood by `iproc.scale`
  -- @return the rescaled image, or `src` itself when no rescale happens
  function pairwise_transform_utils.random_half(src, p, filters)
    if not (torch.uniform() < p) then
      return src
    end
    local chosen_filter = filters[torch.random(1, #filters)]
    local half_dim3 = src:size(3) * 0.5
    local half_dim2 = src:size(2) * 0.5
    return iproc.scale(src, half_dim3, half_dim2, chosen_filter)
  end
  --- Cut a random `max_size` x `max_size` window out of an oversized image.
  -- The crop only happens when BOTH dimension 2 and dimension 3 of `src`
  -- exceed `max_size`; otherwise `src` is returned as-is.
  --
  -- Up to four random windows are tried. A window whose pixels are all
  -- identical (standard deviation of exactly zero) is treated as plain
  -- background and rejected in favour of another draw. If every attempt is
  -- flat, the last window is returned anyway.
  -- @param src image tensor (assumed channel x height x width -- TODO confirm)
  -- @param max_size edge length of the crop; must be a multiple of 4
  -- @param mod optional; when given, the window origin is rounded down to a
  --   multiple of this value
  -- @return the cropped window, or `src` when it is not larger than `max_size`
  function pairwise_transform_utils.crop_if_large(src, max_size, mod)
    local tries = 4
    if not (src:size(2) > max_size and src:size(3) > max_size) then
      return src
    end
    assert(max_size % 4 == 0, "max_size must be a multiple of 4")

    local rect
    for _ = 1, tries do
      local yi = torch.random(0, src:size(2) - max_size)
      local xi = torch.random(0, src:size(3) - max_size)
      if mod then
        -- Snap the origin down onto the requested grid.
        yi = yi - (yi % mod)
        xi = xi - (xi % mod)
      end
      rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
      -- Ignore simple background: only accept a window with some variation.
      -- (The previous `>= 0` test was always true, so no retry ever happened.)
      if rect:float():std() > 0 then
        break
      end
    end
    return rect
  end
  --- Cut matching random windows out of an oversized (x, y) image pair.
  -- The decision and the window position are made on `y`: the crop only
  -- happens when BOTH dimension 2 and dimension 3 of `y` exceed `max_size`.
  -- The window applied to `x` is the `y` window with every coordinate divided
  -- by `scale_y`.
  --
  -- Up to four random windows are tried. A `y` window whose pixels are all
  -- identical (standard deviation of exactly zero) is treated as plain
  -- background and rejected in favour of another draw. If every attempt is
  -- flat, the last window is used anyway.
  -- @param x first image of the pair (presumably the lower-resolution input)
  -- @param y second image of the pair; its size drives the crop
  -- @param scale_y ratio between `y` coordinates and `x` coordinates
  -- @param max_size edge length of the `y` crop; must be a multiple of 4
  -- @param mod optional; when given, the window origin is rounded down to a
  --   multiple of this value. NOTE(review): pass a multiple of `scale_y` so
  --   the derived `x` coordinates stay integral.
  -- @return cropped x, cropped y -- or the untouched `x, y` when `y` is not
  --   larger than `max_size`
  function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
    local tries = 4
    if not (y:size(2) > max_size and y:size(3) > max_size) then
      return x, y
    end
    assert(max_size % 4 == 0, "max_size must be a multiple of 4")

    local rect_y
    local xi, yi
    for _ = 1, tries do
      yi = torch.random(0, y:size(2) - max_size)
      xi = torch.random(0, y:size(3) - max_size)
      if mod then
        -- Snap the origin down onto the requested grid.
        yi = yi - (yi % mod)
        xi = xi - (xi % mod)
      end
      rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
      -- Ignore simple background: only accept a window with some variation.
      -- (The previous `>= 0` test was always true, so no retry ever happened.)
      if rect_y:float():std() > 0 then
        break
      end
    end

    -- Crop x once, for the window that was finally kept.
    local rect_x = iproc.crop(x,
                              xi / scale_y,
                              yi / scale_y,
                              xi / scale_y + max_size / scale_y,
                              yi / scale_y + max_size / scale_y)
    return rect_x, rect_y
  end
  --- Run the random augmentation pipeline on a single source image.
  -- Two pipelines exist, selected by `options.data.filters`:
  --
  -- * exactly one filter named "Box": crop (origin aligned to 2), colour
  --   noise, overlay, unsharp mask, then `iproc.crop_mod4`;
  -- * anything else: random half-size, crop, blur, colour noise, overlay,
  --   unsharp mask, then `data_augmentation.shift_1px`.
  --
  -- In both pipelines the crop limit is `max(crop_size * 2, options.max_size)`.
  -- @param src source image tensor
  -- @param crop_size training crop size; doubled to form the lower bound of
  --   the crop limit
  -- @param options table of settings; the fields read here are `data.filters`,
  --   `max_size`, `downsampling_filters` and the `random_*` rates and ranges
  -- @return the augmented image
  function pairwise_transform_utils.preprocess(src, crop_size, options)
    local utils = pairwise_transform_utils
    local aug = data_augmentation
    local filters = options.data.filters
    local out = src

    if filters and #filters == 1 and filters[1] == "Box" then
      local align = 2 -- assert pos % 2 == 0
      local limit = math.max(crop_size * 2, options.max_size)
      out = utils.crop_if_large(out, limit, align)
      out = aug.color_noise(out, options.random_color_noise_rate)
      out = aug.overlay(out, options.random_overlay_rate)
      out = aug.unsharp_mask(out, options.random_unsharp_mask_rate)
      out = iproc.crop_mod4(out)
      return out
    end

    out = utils.random_half(out, options.random_half_rate, options.downsampling_filters)
    out = utils.crop_if_large(out, math.max(crop_size * 2, options.max_size))
    out = aug.blur(out,
                   options.random_blur_rate,
                   options.random_blur_size,
                   options.random_blur_sigma_min,
                   options.random_blur_sigma_max)
    out = aug.color_noise(out, options.random_color_noise_rate)
    out = aug.overlay(out, options.random_overlay_rate)
    out = aug.unsharp_mask(out, options.random_unsharp_mask_rate)
    out = aug.shift_1px(out)
    return out
  end
  --- Run the paired augmentation pipeline on a user-supplied (x, y) pair.
  -- Steps, in order: paired crop to `options.max_size` (window origin aligned
  -- to `scale_y`), paired rotate, paired scale, paired negate, negate of x
  -- only, `iproc.crop_mod4` on each image, and finally an optional
  -- binarisation of `y`.
  --
  -- The lower bound of the random scale is raised to
  -- `size / (1 + shortest side of x)` and the upper bound is never allowed
  -- below that lower bound.
  -- @param x first image of the pair
  -- @param y second image of the pair
  -- @param scale_y ratio between `y` coordinates and `x` coordinates
  -- @param size target crop size used to bound the random scale
  -- @param options table of settings; reads `max_size`, the
  --   `random_pairwise_*` fields and `pairwise_y_binary`
  -- @return augmented x, augmented y
  function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
    local aug = data_augmentation

    x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
    x, y = aug.pairwise_rotate(x, y,
                               options.random_pairwise_rotate_rate,
                               options.random_pairwise_rotate_min,
                               options.random_pairwise_rotate_max)

    local lower = math.max(options.random_pairwise_scale_min,
                           size / (1 + math.min(x:size(2), x:size(3))))
    local upper = math.max(lower, options.random_pairwise_scale_max)
    x, y = aug.pairwise_scale(x, y, options.random_pairwise_scale_rate, lower, upper)

    x, y = aug.pairwise_negate(x, y, options.random_pairwise_negate_rate)
    x, y = aug.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)

    x = iproc.crop_mod4(x)
    y = iproc.crop_mod4(y)

    if options.pairwise_y_binary then
      -- Two-step threshold: everything below 128 drops to 0, and whatever
      -- is still non-zero afterwards is pushed up to 255.
      local below = torch.lt(y, 128)
      y[below] = 0
      local remaining = torch.gt(y, 0)
      y[remaining] = 255
    end
    return x, y
  end
  --- Pick a training crop from an (x, y) pair, optionally preferring detail.
  -- A `torch.uniform()` draw `r` decides the strategy:
  --
  -- * `p < r`: one random window is taken;
  -- * otherwise: `tries` random windows are drawn, and the one with the
  --   largest summed squared difference between `y` and `lowres_y` is kept.
  --
  -- Window origins are drawn in `x` coordinates from
  -- `[1, x:size(d) - (size + 1)]` and multiplied by `scale` to get `y`
  -- coordinates. The `y` crop is `size` x `size`; the `x` crop covers the
  -- same region divided by `scale`.
  -- @param x lower-resolution image of the pair
  -- @param y image whose dimensions 2 and 3 are exactly `scale` times those of `x`
  -- @param lowres_y reference image indexed with the same coordinates as `y`;
  --   only read by the multi-try strategy
  -- @param size edge length of the `y` crop; must be divisible by `scale`
  -- @param scale ratio between `y` coordinates and `x` coordinates
  -- @param p threshold for the uniform draw (see above)
  -- @param tries number of candidate windows for the multi-try strategy
  -- @return cropped x, cropped y
  function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    -- The condition comes first and the message second. The original call had
    -- them swapped, so these checks could never fail.
    assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
           "active_cropping: y must be exactly `scale` times the size of x")
    assert(size % scale == 0,
           "active_cropping: crop size must be divisible by scale")

    local r = torch.uniform()
    local x_size = size / scale
    local max_col = x:size(3) - (size + 1)
    local max_row = x:size(2) - (size + 1)

    if p < r then
      -- Plain random crop.
      local xi = torch.random(1, max_col) * scale
      local yi = torch.random(1, max_row) * scale
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x, xi / scale, yi / scale,
                            xi / scale + x_size, yi / scale + x_size)
      return xc, yc
    end

    -- Multi-try crop: gather candidate windows from y and lowres_y.
    local y_crops = torch.LongTensor(tries, y:size(1), size, size)
    local lowres_crops = torch.LongTensor(tries, lowres_y:size(1), size, size)
    local offsets = torch.LongTensor(2, tries)
    offsets[1]:random(1, max_col):mul(scale)
    offsets[2]:random(1, max_row):mul(scale)
    for i = 1, tries do
      local xi = offsets[1][i]
      local yi = offsets[2][i]
      y_crops[i]:copy(iproc.crop_nocopy(y, xi, yi, xi + size, yi + size))
      lowres_crops[i]:copy(iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size))
    end

    -- Squared difference per element, summed per candidate; keep the largest.
    y_crops:csub(lowres_crops)
    y_crops:cmul(y_crops)
    local n = y_crops:size(1)
    local _, best = y_crops:reshape(n, y_crops:nElement() / n)
                           :transpose(1, 2):sum(1):topk(1, true)
    local best_xi = offsets[1][best[1][1]]
    local best_yi = offsets[2][best[1][1]]

    local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
    local xc = iproc.crop(x, best_xi / scale, best_yi / scale,
                          best_xi / scale + x_size, best_yi / scale + x_size)
    return xc, yc
  end
  --- Expand an image set into its eight flip/transpose variants.
  -- Two passes are made: one over the inputs as given, one over the inputs
  -- with dimensions 2 and 3 transposed. Each pass appends four entries per
  -- image, in this order: the image, its vertical flip, its horizontal flip,
  -- and the horizontal flip of its vertical flip.
  -- @param x input image (required)
  -- @param y target image (required)
  -- @param lowres_y optional companion image; skipped when nil
  -- @param x_noise optional companion image; skipped when nil
  -- @return xs, ys, ls, ns -- variant lists for `x`, `y`, `lowres_y` and
  --   `x_noise`; `ls` / `ns` stay empty when their input was not supplied
  function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
    local xs, ys, ls, ns = {}, {}, {}, {}

    -- Append img, vflip(img), hflip(img), hflip(vflip(img)) to list.
    local function push_flips(list, img)
      local flipped = iproc.vflip(img)
      table.insert(list, img)
      table.insert(list, flipped)
      table.insert(list, iproc.hflip(img))
      table.insert(list, iproc.hflip(flipped))
    end

    for pass = 1, 2 do
      -- TTA
      local xi, yi, ri, ni = x, y, lowres_y, x_noise
      if pass == 2 then
        xi = x:transpose(2, 3):contiguous()
        ni = nil
        if x_noise then
          ni = x_noise:transpose(2, 3):contiguous()
        end
        yi = y:transpose(2, 3):contiguous()
        ri = nil
        if lowres_y then
          ri = lowres_y:transpose(2, 3):contiguous()
        end
      end

      push_flips(xs, xi)
      if ni then
        push_flips(ns, ni)
      end
      push_flips(ys, yi)
      if ri then
        push_flips(ls, ri)
      end
    end

    return xs, ys, ls, ns
  end
  --- Build a CUDA network that averages 2x2 blocks and upsamples them back.
  -- The sequence is `nn.SpatialAveragePooling(2, 2, 2, 2)` followed by
  -- `nn.SpatialUpSamplingNearest(2)`, converted with `:cuda()`.
  -- @return the `nn.Sequential` container after the `:cuda()` call
  local function lowres_model()
    local net = nn.Sequential()
    local pool = nn.SpatialAveragePooling(2, 2, 2, 2)
    local upsample = nn.SpatialUpSamplingNearest(2)
    net:add(pool)
    net:add(upsample)
    return net:cuda()
  end
  -- Cache slots for the disabled GPU/CPU auto-selection kept below; the live
  -- code path never reads them.
  local g_lowres_model = nil
  local g_lowres_gpu = nil

  --- Produce a blurred copy of an image by halving and restoring its size.
  -- The image is wrapped in a `gm.Image` ("RGB", "DHW"), resized to half of
  -- its dimension-3 / dimension-2 extents with the "Box" filter, resized back
  -- to the original extents with the same filter, and converted to a byte
  -- tensor ("RGB", "DHW").
  -- @param src image tensor in DHW layout
  -- @return byte tensor with the same dimension-2 / dimension-3 size as `src`
  function pairwise_transform_utils.low_resolution(src)
    --[[
    -- Disabled: benchmark a GPU path against the CPU path on first use and
    -- keep whichever is faster.
    -- I am not sure that the following process is thread-safe
    g_lowres_model = g_lowres_model or lowres_model()
    if g_lowres_gpu == nil then
      --benchmark
      local gpu_time = sys.clock()
      for i = 1, 10 do
        g_lowres_model:forward(src:cuda()):byte()
      end
      gpu_time = sys.clock() - gpu_time
      local cpu_time = sys.clock()
      for i = 1, 10 do
        gm.Image(src, "RGB", "DHW"):
          size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
          size(src:size(3), src:size(2), "Box"):
          toTensor("byte", "RGB", "DHW")
      end
      cpu_time = sys.clock() - cpu_time
      --print(gpu_time, cpu_time)
      if gpu_time < cpu_time then
        g_lowres_gpu = true
      else
        g_lowres_gpu = false
      end
    end
    if g_lowres_gpu then
      return g_lowres_model:forward(src:cuda()):byte()
    else
      return gm.Image(src, "RGB", "DHW"):
        size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
        size(src:size(3), src:size(2), "Box"):
        toTensor("byte", "RGB", "DHW")
    end
    --]]
    local image = gm.Image(src, "RGB", "DHW")
    local halved = image:size(src:size(3) * 0.5, src:size(2) * 0.5, "Box")
    local restored = halved:size(src:size(3), src:size(2), "Box")
    return restored:toTensor("byte", "RGB", "DHW")
  end
  269. return pairwise_transform_utils