pairwise_transform_utils.lua 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. require 'cunn'
  2. local iproc = require 'iproc'
  3. local gm = {}
  4. gm.Image = require 'graphicsmagick.Image'
  5. local data_augmentation = require 'data_augmentation'
  6. local pairwise_transform_utils = {}
  7. function pairwise_transform_utils.random_half(src, p, filters)
  8. if torch.uniform() < p then
  9. local filter = filters[torch.random(1, #filters)]
  10. return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
  11. else
  12. return src
  13. end
  14. end
  15. function pairwise_transform_utils.crop_if_large(src, max_size, mod)
  16. local tries = 4
  17. if src:size(2) > max_size and src:size(3) > max_size then
  18. assert(max_size % 4 == 0)
  19. local rect
  20. for i = 1, tries do
  21. local yi = torch.random(0, src:size(2) - max_size)
  22. local xi = torch.random(0, src:size(3) - max_size)
  23. if mod then
  24. yi = yi - (yi % mod)
  25. xi = xi - (xi % mod)
  26. end
  27. rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
  28. -- ignore simple background
  29. if rect:float():std() >= 0 then
  30. break
  31. end
  32. end
  33. return rect
  34. else
  35. return src
  36. end
  37. end
  38. function pairwise_transform_utils.preprocess(src, crop_size, options)
  39. local dest = src
  40. local box_only = false
  41. if options.data.filters then
  42. if #options.data.filters == 1 and options.data.filters[1] == "Box" then
  43. box_only = true
  44. end
  45. end
  46. if box_only then
  47. local mod = 2 -- assert pos % 2 == 0
  48. dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size), mod)
  49. dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
  50. dest = data_augmentation.overlay(dest, options.random_overlay_rate)
  51. dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
  52. dest = iproc.crop_mod4(dest)
  53. else
  54. dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
  55. dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
  56. dest = data_augmentation.blur(dest, options.random_blur_rate,
  57. options.random_blur_size,
  58. options.random_blur_min,
  59. options.random_blur_max)
  60. dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
  61. dest = data_augmentation.overlay(dest, options.random_overlay_rate)
  62. dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
  63. dest = data_augmentation.shift_1px(dest)
  64. end
  65. return dest
  66. end
  67. function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
  68. assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
  69. assert("crop_size % scale == 0", size % scale == 0)
  70. local r = torch.uniform()
  71. local t = "float"
  72. if x:type() == "torch.ByteTensor" then
  73. t = "byte"
  74. end
  75. if p < r then
  76. local xi = torch.random(1, x:size(3) - (size + 1)) * scale
  77. local yi = torch.random(1, x:size(2) - (size + 1)) * scale
  78. local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
  79. local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
  80. return xc, yc
  81. else
  82. local xcs = torch.LongTensor(tries, y:size(1), size, size)
  83. local lcs = torch.LongTensor(tries, lowres_y:size(1), size, size)
  84. local rects = {}
  85. local r = torch.LongTensor(2, tries)
  86. r[1]:random(1, x:size(3) - (size + 1)):mul(scale)
  87. r[2]:random(1, x:size(2) - (size + 1)):mul(scale)
  88. for i = 1, tries do
  89. local xi = r[1][i]
  90. local yi = r[2][i]
  91. local xc = iproc.crop_nocopy(y, xi, yi, xi + size, yi + size)
  92. local lc = iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size)
  93. xcs[i]:copy(xc)
  94. lcs[i]:copy(lc)
  95. rects[i] = {xi, yi}
  96. end
  97. xcs:csub(lcs)
  98. xcs:cmul(xcs)
  99. local v, l = xcs:reshape(xcs:size(1), xcs:nElement() / xcs:size(1)):transpose(1, 2):sum(1):topk(1, true)
  100. local best_xi = rects[l[1][1]][1]
  101. local best_yi = rects[l[1][1]][2]
  102. local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
  103. local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + size / scale, best_yi / scale + size / scale)
  104. return xc, yc
  105. end
  106. end
  107. function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
  108. local xs = {}
  109. local ns = {}
  110. local ys = {}
  111. local ls = {}
  112. for j = 1, 2 do
  113. -- TTA
  114. local xi, yi, ri
  115. if j == 1 then
  116. xi = x
  117. ni = x_noise
  118. yi = y
  119. ri = lowres_y
  120. else
  121. xi = x:transpose(2, 3):contiguous()
  122. if x_noise then
  123. ni = x_noise:transpose(2, 3):contiguous()
  124. end
  125. yi = y:transpose(2, 3):contiguous()
  126. ri = lowres_y:transpose(2, 3):contiguous()
  127. end
  128. local xv = iproc.vflip(xi)
  129. local nv
  130. if x_noise then
  131. nv = iproc.vflip(ni)
  132. end
  133. local yv = iproc.vflip(yi)
  134. local rv = iproc.vflip(ri)
  135. table.insert(xs, xi)
  136. if ni then
  137. table.insert(ns, ni)
  138. end
  139. table.insert(ys, yi)
  140. table.insert(ls, ri)
  141. table.insert(xs, xv)
  142. if nv then
  143. table.insert(ns, nv)
  144. end
  145. table.insert(ys, yv)
  146. table.insert(ls, rv)
  147. table.insert(xs, iproc.hflip(xi))
  148. if ni then
  149. table.insert(ns, iproc.hflip(ni))
  150. end
  151. table.insert(ys, iproc.hflip(yi))
  152. table.insert(ls, iproc.hflip(ri))
  153. table.insert(xs, iproc.hflip(xv))
  154. if nv then
  155. table.insert(ns, iproc.hflip(nv))
  156. end
  157. table.insert(ys, iproc.hflip(yv))
  158. table.insert(ls, iproc.hflip(rv))
  159. end
  160. return xs, ys, ls, ns
  161. end
  162. local function lowres_model()
  163. local seq = nn.Sequential()
  164. seq:add(nn.SpatialAveragePooling(2, 2, 2, 2))
  165. seq:add(nn.SpatialUpSamplingNearest(2))
  166. return seq:cuda()
  167. end
  168. local g_lowres_model = nil
  169. local g_lowres_gpu = nil
  170. function pairwise_transform_utils.low_resolution(src)
  171. --[[
  172. -- I am not sure that the following process is thraed-safe
  173. g_lowres_model = g_lowres_model or lowres_model()
  174. if g_lowres_gpu == nil then
  175. --benchmark
  176. local gpu_time = sys.clock()
  177. for i = 1, 10 do
  178. g_lowres_model:forward(src:cuda()):byte()
  179. end
  180. gpu_time = sys.clock() - gpu_time
  181. local cpu_time = sys.clock()
  182. for i = 1, 10 do
  183. gm.Image(src, "RGB", "DHW"):
  184. size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
  185. size(src:size(3), src:size(2), "Box"):
  186. toTensor("byte", "RGB", "DHW")
  187. end
  188. cpu_time = sys.clock() - cpu_time
  189. --print(gpu_time, cpu_time)
  190. if gpu_time < cpu_time then
  191. g_lowres_gpu = true
  192. else
  193. g_lowres_gpu = false
  194. end
  195. end
  196. if g_lowres_gpu then
  197. return g_lowres_model:forward(src:cuda()):byte()
  198. else
  199. return gm.Image(src, "RGB", "DHW"):
  200. size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
  201. size(src:size(3), src:size(2), "Box"):
  202. toTensor("byte", "RGB", "DHW")
  203. end
  204. --]]
  205. return gm.Image(src, "RGB", "DHW"):
  206. size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
  207. size(src:size(3), src:size(2), "Box"):
  208. toTensor("byte", "RGB", "DHW")
  209. end
  210. return pairwise_transform_utils