pairwise_transform_utils.lua 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. require 'cunn'
  2. local iproc = require 'iproc'
  3. local gm = {}
  4. gm.Image = require 'graphicsmagick.Image'
  5. local data_augmentation = require 'data_augmentation'
  6. local pairwise_transform_utils = {}
  --- Randomly downscale an image to half its size.
  -- A uniform random number is drawn first; when it is below `p`, one entry of
  -- `filters` is picked at random and `src` is rescaled via `iproc.scale` to
  -- half of `src:size(3)` by half of `src:size(2)`. Otherwise `src` is returned
  -- untouched.
  -- @param src image tensor (presumably channels x height x width; confirm
  --   against iproc.scale)
  -- @param p probability of halving the image
  -- @param filters non-empty array of filter names accepted by iproc.scale
  -- @return the half-size image, or `src` itself when no scaling is applied
  function pairwise_transform_utils.random_half(src, p, filters)
    -- Guard clause: keep the source as-is unless the draw falls below p.
    if not (torch.uniform() < p) then
      return src
    end
    -- The filter is drawn only when the image is actually going to be scaled.
    local chosen_filter = filters[torch.random(1, #filters)]
    local half_w = src:size(3) * 0.5
    local half_h = src:size(2) * 0.5
    return iproc.scale(src, half_w, half_h, chosen_filter)
  end
  --- Cut a random max_size x max_size square out of an oversized image.
  -- The image is cropped only when BOTH `src:size(2)` and `src:size(3)` exceed
  -- `max_size`; otherwise `src` is returned unchanged.
  -- @param src image tensor (presumably channels x height x width; confirm
  --   against iproc.crop)
  -- @param max_size side length of the crop; must be a multiple of 4 when a
  --   crop is actually performed
  -- @param mod optional alignment; when given, both crop offsets are rounded
  --   down to a multiple of `mod`
  -- @return the cropped square, or `src` itself when it is not large enough
  function pairwise_transform_utils.crop_if_large(src, max_size, mod)
    local MAX_ATTEMPTS = 4

    -- Nothing to do unless the image is larger than max_size on both axes.
    if src:size(2) <= max_size or src:size(3) <= max_size then
      return src
    end
    assert(max_size % 4 == 0)

    -- Round an offset down to a multiple of `mod` (no-op without `mod`).
    local function align(offset)
      if mod then
        return offset - (offset % mod)
      end
      return offset
    end

    local patch
    for _ = 1, MAX_ATTEMPTS do
      -- Draw order matters for reproducibility: row offset first, then column.
      local top = align(torch.random(0, src:size(2) - max_size))
      local left = align(torch.random(0, src:size(3) - max_size))
      patch = iproc.crop(src, left, top, left + max_size, top + max_size)
      -- ignore simple background
      -- NOTE(review): a standard deviation is never negative, so this test
      -- accepts the first crop (barring NaN) and the retries never run. It
      -- looks like a positive threshold was intended -- confirm before changing.
      if patch:float():std() >= 0 then
        break
      end
    end
    return patch
  end
  --- Apply the random crop / augmentation pipeline to a training image.
  -- Two pipelines exist, selected by `options.data.filters`:
  --   * exactly {"Box"}: crop_if_large (offsets aligned to 2), flip,
  --     color_noise, overlay, unsharp_mask, then iproc.crop_mod4;
  --   * anything else: random_half, crop_if_large, flip, color_noise,
  --     overlay, unsharp_mask, then data_augmentation.shift_1px.
  -- @param src source image tensor
  -- @param crop_size training crop size; the large-image crop limit is
  --   max(crop_size * 2, options.max_size)
  -- @param options table providing `data.filters`, `max_size`,
  --   `random_half_rate`, `downsampling_filters`, `random_color_noise_rate`,
  --   `random_overlay_rate` and `random_unsharp_mask_rate`
  -- @return the preprocessed image
  function pairwise_transform_utils.preprocess(src, crop_size, options)
    local filters = options.data.filters
    local box_only = filters and #filters == 1 and filters[1] == "Box"

    -- Augmentation steps shared by both pipelines, always in this order.
    local function augment(img)
      img = data_augmentation.flip(img)
      img = data_augmentation.color_noise(img, options.random_color_noise_rate)
      img = data_augmentation.overlay(img, options.random_overlay_rate)
      img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
      return img
    end

    local result = src
    if box_only then
      local offset_mod = 2 -- assert pos % 2 == 0
      local limit = math.max(crop_size * 2, options.max_size)
      result = pairwise_transform_utils.crop_if_large(result, limit, offset_mod)
      result = augment(result)
      result = iproc.crop_mod4(result)
    else
      result = pairwise_transform_utils.random_half(
        result, options.random_half_rate, options.downsampling_filters)
      local limit = math.max(crop_size * 2, options.max_size)
      result = pairwise_transform_utils.crop_if_large(result, limit)
      result = augment(result)
      result = data_augmentation.shift_1px(result)
    end
    return result
  end
  --- Crop a matching (x, y) training pair, optionally favouring detailed areas.
  -- A uniform random number is drawn; if it is greater than `p`, a single
  -- random position is used. Otherwise `tries` random candidate positions are
  -- scored by the summed squared difference between `y` and `lowres_y` inside
  -- the candidate window, and the highest-scoring position is used.
  -- @param x input image; `x:size(2) * scale` and `x:size(3) * scale` must
  --   equal the corresponding sizes of `y`
  -- @param y target image
  -- @param lowres_y degraded version of `y`, cropped with the same coordinates
  --   as `y`
  -- @param size crop side length in `y` coordinates; must be a multiple of
  --   `scale`
  -- @param scale ratio between `y` and `x` sizes
  -- @param p probability of using the scored ("active") selection
  -- @param tries number of candidate positions for the scored selection
  -- @return xc, yc the crop of `x` (side `size / scale`) and the crop of `y`
  --   (side `size`)
  function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    -- assert() takes (condition, message). The original passed the message
    -- first, which is always truthy, so neither precondition was ever checked.
    assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
           "x:size == y:size")
    assert(size % scale == 0, "crop_size % scale == 0")

    -- Crop both images at the y-space offset (xi, yi); x uses the same window
    -- divided by `scale`.
    local function crop_pair(xi, yi)
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x, xi / scale, yi / scale,
                            xi / scale + size / scale, yi / scale + size / scale)
      return xc, yc
    end

    -- Offsets are drawn in x coordinates and multiplied by `scale`.
    local max_xi = x:size(3) - (size + 1)
    local max_yi = x:size(2) - (size + 1)

    if p < torch.uniform() then
      -- Plain random crop.
      local xi = torch.random(1, max_xi) * scale
      local yi = torch.random(1, max_yi) * scale
      return crop_pair(xi, yi)
    end

    -- Scored selection: gather `tries` candidate windows from y and lowres_y.
    -- NOTE(review): candidates are copied into LongTensors, so float images in
    -- [0, 1] would presumably truncate to 0 -- verify callers pass byte images.
    local ycs = torch.LongTensor(tries, y:size(1), size, size)
    local lcs = torch.LongTensor(tries, lowres_y:size(1), size, size)
    local offsets = torch.LongTensor(2, tries)
    offsets[1]:random(1, max_xi):mul(scale)
    offsets[2]:random(1, max_yi):mul(scale)
    for i = 1, tries do
      local xi, yi = offsets[1][i], offsets[2][i]
      ycs[i]:copy(iproc.crop_nocopy(y, xi, yi, xi + size, yi + size))
      lcs[i]:copy(iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size))
    end

    -- Per-candidate score: sum over the window of (y - lowres_y)^2.
    ycs:csub(lcs)
    ycs:cmul(ycs)
    local scores = ycs:reshape(tries, ycs:nElement() / tries):transpose(1, 2):sum(1)
    local _, best = scores:topk(1, true)
    local best_i = best[1][1]
    return crop_pair(offsets[1][best_i], offsets[2][best_i])
  end
  --- Build flipped / transposed variants of a training sample.
  -- For the original images and for their `transpose(2, 3)` copies, four
  -- variants are appended to each output list, in this order: the image,
  -- its vflip, its hflip, and the hflip of its vflip (8 entries per list).
  -- @param x input image tensor
  -- @param y target image tensor
  -- @param lowres_y degraded target image tensor
  -- @param x_noise optional extra input image; when nil, the returned `ns`
  --   list is empty
  -- @return xs, ys, ls, ns lists of variants of x, y, lowres_y and x_noise
  function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
    local xs, ns, ys, ls = {}, {}, {}, {}

    -- Append the four flip variants of `img` to `list`.
    local function append_variants(list, img)
      local flipped_v = iproc.vflip(img)
      list[#list + 1] = img
      list[#list + 1] = flipped_v
      list[#list + 1] = iproc.hflip(img)
      list[#list + 1] = iproc.hflip(flipped_v)
    end

    for j = 1, 2 do
      -- TTA
      -- `ni` is declared local here; the original assigned it without a
      -- declaration and so leaked it into the global table.
      local xi, ni, yi, ri
      if j == 1 then
        xi, ni, yi, ri = x, x_noise, y, lowres_y
      else
        xi = x:transpose(2, 3):contiguous()
        if x_noise then
          ni = x_noise:transpose(2, 3):contiguous()
        end
        yi = y:transpose(2, 3):contiguous()
        ri = lowres_y:transpose(2, 3):contiguous()
      end

      append_variants(xs, xi)
      if ni then
        append_variants(ns, ni)
      end
      append_variants(ys, yi)
      append_variants(ls, ri)
    end
    return xs, ys, ls, ns
  end
  --- Build a CUDA network that average-pools 2x2 blocks (stride 2) and then
  -- upsamples by 2 with nearest-neighbour interpolation.
  -- Only referenced by the disabled GPU path inside low_resolution.
  -- @return the nn.Sequential container moved to the GPU via :cuda()
  local function lowres_model()
    local net = nn.Sequential()
      :add(nn.SpatialAveragePooling(2, 2, 2, 2))
      :add(nn.SpatialUpSamplingNearest(2))
    return net:cuda()
  end
  -- State for the disabled GPU path below: the cached model and the result of
  -- the one-time GPU-vs-CPU benchmark. Both are currently unused.
  local g_lowres_model = nil
  local g_lowres_gpu = nil

  --- Produce a blurred, same-size copy of an image.
  -- The image is resized to half its width and height with the "Box" filter,
  -- resized back to the original dimensions with the same filter, and
  -- converted to a byte RGB tensor in "DHW" layout.
  -- @param src RGB image tensor in "DHW" layout
  -- @return byte tensor with the same width and height as `src`
  function pairwise_transform_utils.low_resolution(src)
    --[[
    -- Disabled: pick the GPU or CPU implementation after a one-time benchmark.
    -- I am not sure that the following process is thread-safe
    g_lowres_model = g_lowres_model or lowres_model()
    if g_lowres_gpu == nil then
      --benchmark
      local gpu_time = sys.clock()
      for i = 1, 10 do
        g_lowres_model:forward(src:cuda()):byte()
      end
      gpu_time = sys.clock() - gpu_time
      local cpu_time = sys.clock()
      for i = 1, 10 do
        gm.Image(src, "RGB", "DHW"):
          size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
          size(src:size(3), src:size(2), "Box"):
          toTensor("byte", "RGB", "DHW")
      end
      cpu_time = sys.clock() - cpu_time
      --print(gpu_time, cpu_time)
      if gpu_time < cpu_time then
        g_lowres_gpu = true
      else
        g_lowres_gpu = false
      end
    end
    if g_lowres_gpu then
      return g_lowres_model:forward(src:cuda()):byte()
    else
      return gm.Image(src, "RGB", "DHW"):
        size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
        size(src:size(3), src:size(2), "Box"):
        toTensor("byte", "RGB", "DHW")
    end
    --]]
    local image = gm.Image(src, "RGB", "DHW")
    -- Shrink to half size, then enlarge back to the original size.
    local shrunk = image:size(src:size(3) * 0.5, src:size(2) * 0.5, "Box")
    local restored = shrunk:size(src:size(3), src:size(2), "Box")
    return restored:toTensor("byte", "RGB", "DHW")
  end
  208. return pairwise_transform_utils