123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
-- Module dependencies. `cunn` is required only for its side effects
-- (presumably needed by the `:cuda()` call in lowres_model — confirm).
require("cunn")
local iproc = require("iproc")
-- GraphicsMagick is only reached through `gm.Image`.
local gm = { Image = require("graphicsmagick.Image") }
local data_augmentation = require("data_augmentation")

-- Exported module table; every public helper below is attached to it.
local pairwise_transform_utils = {}
--- With probability `p`, shrink `src` to half its width and height.
-- The resampling filter is picked uniformly at random from `filters`.
-- The code treats dimension 3 as width and dimension 2 as height.
-- @param src image tensor
-- @param p probability of downscaling
-- @param filters non-empty list of filter names passed to iproc.scale
-- @return the halved image, or `src` itself when no downscale happens
function pairwise_transform_utils.random_half(src, p, filters)
  if not (torch.uniform() < p) then
    return src
  end
  local chosen_filter = filters[torch.random(1, #filters)]
  local half_width = src:size(3) * 0.5
  local half_height = src:size(2) * 0.5
  return iproc.scale(src, half_width, half_height, chosen_filter)
end
-- How many random positions are tried before accepting a crop that may be
-- a flat (single-colour) region.
local CROP_TRIES = 4

-- Draw the top-left corner (xi, yi) of a max_size x max_size window inside
-- an image of the given height and width. When `mod` is given, both
-- coordinates are snapped down to a multiple of `mod`.
-- yi is drawn before xi, matching the historical random-number order.
local function random_crop_origin(height, width, max_size, mod)
  local yi = torch.random(0, height - max_size)
  local xi = torch.random(0, width - max_size)
  if mod then
    yi = yi - (yi % mod)
    xi = xi - (xi % mod)
  end
  return xi, yi
end

-- True when the crop is not a flat background, i.e. its pixel values vary.
-- (A standard deviation is never negative, so the test must be strict.)
local function has_detail(rect)
  return rect:float():std() > 0
end

--- Randomly crop `src` to max_size x max_size if it is larger than that in
-- both dimension 2 and dimension 3.
-- Up to CROP_TRIES positions are tried. The first crop with non-zero
-- variance is returned; if every try is flat, the last crop is returned.
-- @param src image tensor
-- @param max_size crop side; must be a multiple of 4
-- @param mod optional alignment for the crop origin
-- @return the crop, or `src` unchanged when it is not larger than max_size
function pairwise_transform_utils.crop_if_large(src, max_size, mod)
  if not (src:size(2) > max_size and src:size(3) > max_size) then
    return src
  end
  assert(max_size % 4 == 0)
  local rect
  for _ = 1, CROP_TRIES do
    local xi, yi = random_crop_origin(src:size(2), src:size(3), max_size, mod)
    rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
    if has_detail(rect) then
      break
    end
  end
  return rect
end

--- Paired version of crop_if_large for a (low-res x, high-res y) pair.
-- The crop position is chosen on `y`, with the same flat-background retry
-- as crop_if_large. `x` is then cropped at the same position divided by
-- `scale_y`, with side max_size / scale_y. So `x` is expected to be `y`
-- scaled down by `scale_y` (assumption — confirm at call sites). Pass
-- `mod = scale_y` (as preprocess_user does) so those coordinates are integral.
-- @param x low-resolution tensor
-- @param y high-resolution tensor
-- @param scale_y size ratio between y and x
-- @param max_size crop side in y pixels; must be a multiple of 4
-- @param mod optional alignment for the crop origin in y
-- @return rect_x, rect_y, or x, y unchanged when y is not larger than max_size
function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
  if not (y:size(2) > max_size and y:size(3) > max_size) then
    return x, y
  end
  assert(max_size % 4 == 0)
  local rect_y, xi, yi
  for _ = 1, CROP_TRIES do
    xi, yi = random_crop_origin(y:size(2), y:size(3), max_size, mod)
    rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
    if has_detail(rect_y) then
      break
    end
  end
  -- Crop x only once, at the accepted position, in x's coordinate system.
  local lx, ly, lsize = xi / scale_y, yi / scale_y, max_size / scale_y
  local rect_x = iproc.crop(x, lx, ly, lx + lsize, ly + lsize)
  return rect_x, rect_y
end
--- Apply the single-image augmentation pipeline to `src`.
-- When `options.data.filters` is exactly { "Box" }, the pipeline is:
-- crop with even-aligned offsets, colour noise, overlay, unsharp mask, then
-- iproc.crop_mod4.
-- Otherwise it is: random half-size downscale, crop, blur, colour noise,
-- overlay, unsharp mask, then data_augmentation.shift_1px.
-- @param src image tensor
-- @param crop_size crops are at least crop_size * 2 on a side (or options.max_size if larger)
-- @param options training options table (reads options.data.filters and the random_* rates)
-- @return the augmented image
function pairwise_transform_utils.preprocess(src, crop_size, options)
  local filters = options.data.filters
  local box_only = filters and #filters == 1 and filters[1] == "Box" or false
  local crop_limit = math.max(crop_size * 2, options.max_size)

  if box_only then
    -- Offsets are snapped to multiples of 2 so the crop origin stays even.
    local img = pairwise_transform_utils.crop_if_large(src, crop_limit, 2)
    img = data_augmentation.color_noise(img, options.random_color_noise_rate)
    img = data_augmentation.overlay(img, options.random_overlay_rate)
    img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
    return iproc.crop_mod4(img)
  end

  local img = pairwise_transform_utils.random_half(src, options.random_half_rate,
                                                   options.downsampling_filters)
  img = pairwise_transform_utils.crop_if_large(img, crop_limit)
  img = data_augmentation.blur(img, options.random_blur_rate,
                               options.random_blur_size,
                               options.random_blur_sigma_min,
                               options.random_blur_sigma_max)
  img = data_augmentation.color_noise(img, options.random_color_noise_rate)
  img = data_augmentation.overlay(img, options.random_overlay_rate)
  img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
  return data_augmentation.shift_1px(img)
end
--- Apply the paired augmentation pipeline to a user-supplied (x, y) pair.
-- The steps, in order, are: paired crop, paired rotation, paired rescale,
-- paired negation, then crop_mod4 on each image. Finally, when
-- options.pairwise_y_binary is set, y is thresholded: values below 128
-- become 0 and all others become 255.
-- The lower rescale bound is raised so it is at least
-- size / (shorter side of x + 1).
-- @param x low-resolution input
-- @param y target image
-- @param scale_y size ratio between y and x; also used as crop alignment
-- @param size training patch size
-- @param options training options table
-- @return x, y after augmentation
function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
  x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
  x, y = data_augmentation.pairwise_rotate(x, y,
                                           options.random_pairwise_rotate_rate,
                                           options.random_pairwise_rotate_min,
                                           options.random_pairwise_rotate_max)

  local shorter_side = math.min(x:size(2), x:size(3))
  local lo = math.max(options.random_pairwise_scale_min, size / (shorter_side + 1))
  local hi = math.max(lo, options.random_pairwise_scale_max)
  x, y = data_augmentation.pairwise_scale(x, y, options.random_pairwise_scale_rate, lo, hi)

  x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
  x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
  x = iproc.crop_mod4(x)
  y = iproc.crop_mod4(y)

  if options.pairwise_y_binary then
    local dark = torch.lt(y, 128)
    y[dark] = 0
    local lit = torch.gt(y, 0)
    y[lit] = 255
  end
  return x, y
end
--- Choose a size x size training crop of `y` and the matching crop of `x`.
-- With probability (1 - p), the crop position is uniformly random.
-- Otherwise, `tries` random positions are sampled. The position kept is the
-- one where `y` and `lowres_y` differ most (largest sum of squared
-- differences), presumably to favour regions that the low-resolution
-- version reproduces worst.
-- @param x low-resolution input; dims 2 and 3 times `scale` must equal y's
-- @param y high-resolution target
-- @param lowres_y tensor compared with y; it is cropped at y's coordinates,
--   so it is expected to have y's spatial size (assumption — confirm)
-- @param size crop side in y pixels; must be divisible by `scale`
-- @param scale size ratio between y and x
-- @param p probability of using the active (max-error) search
-- @param tries number of candidate positions in the active search
-- @return xc, yc
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
  -- assert(condition, message): the condition must come first, or it never fires.
  assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
         "x:size == y:size")
  assert(size % scale == 0, "crop_size % scale == 0")

  -- Crop y at (xi, yi), and x at the same position in x's coordinates.
  local function crop_pair(xi, yi)
    local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
    local xc = iproc.crop(x, xi / scale, yi / scale,
                          xi / scale + size / scale, yi / scale + size / scale)
    return xc, yc
  end

  local r = torch.uniform()
  if p < r then
    local xi = torch.random(1, x:size(3) - (size + 1)) * scale
    local yi = torch.random(1, x:size(2) - (size + 1)) * scale
    return crop_pair(xi, yi)
  end

  -- Integer accumulators are exact for byte images. Any other element type
  -- (e.g. float images in [0, 1]) would be truncated when copied into a
  -- LongTensor, so those are accumulated in double precision.
  local Buffer = torch.LongTensor
  if y:type() ~= "torch.ByteTensor" or lowres_y:type() ~= "torch.ByteTensor" then
    Buffer = torch.DoubleTensor
  end

  local y_crops = Buffer(tries, y:size(1), size, size)
  local low_crops = Buffer(tries, lowres_y:size(1), size, size)
  local offsets = torch.LongTensor(2, tries)
  offsets[1]:random(1, x:size(3) - (size + 1)):mul(scale)
  offsets[2]:random(1, x:size(2) - (size + 1)):mul(scale)
  for i = 1, tries do
    local xi, yi = offsets[1][i], offsets[2][i]
    y_crops[i]:copy(iproc.crop_nocopy(y, xi, yi, xi + size, yi + size))
    low_crops[i]:copy(iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size))
  end

  -- Per-candidate sum of squared differences; keep the largest.
  y_crops:csub(low_crops)
  y_crops:cmul(y_crops)
  local _, best = y_crops:reshape(tries, y_crops:nElement() / tries)
                         :transpose(1, 2):sum(1):topk(1, true)
  local k = best[1][1]
  return crop_pair(offsets[1][k], offsets[2][k])
end
-- Append src, vflip(src), hflip(src) and hflip(vflip(src)) to `out`, in
-- that order.
local function append_flips(out, src)
  local flipped_v = iproc.vflip(src)
  out[#out + 1] = src
  out[#out + 1] = flipped_v
  out[#out + 1] = iproc.hflip(src)
  out[#out + 1] = iproc.hflip(flipped_v)
end

--- Build 8 flip/transpose variants of each input, for augmentation or TTA.
-- Pass 1 uses the inputs as given. Pass 2 uses them with dims 2 and 3
-- transposed (made contiguous). Each pass appends the unflipped, vflip,
-- hflip and vflip+hflip versions, so the four returned lists line up
-- index by index.
-- @param x input tensor
-- @param y target tensor
-- @param lowres_y companion tensor, transformed like y
-- @param x_noise optional tensor; when nil, the returned `ns` list is empty
-- @return xs, ys, ls, ns (8 entries each; `ns` is empty without x_noise)
function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
  local xs, ys, ls, ns = {}, {}, {}, {}
  for pass = 1, 2 do
    -- All four are locals: `ni` previously leaked into the global table.
    local xi, yi, ri, ni = x, y, lowres_y, x_noise
    if pass == 2 then
      xi = x:transpose(2, 3):contiguous()
      if x_noise then
        ni = x_noise:transpose(2, 3):contiguous()
      end
      yi = y:transpose(2, 3):contiguous()
      ri = lowres_y:transpose(2, 3):contiguous()
    end
    append_flips(xs, xi)
    if ni then
      append_flips(ns, ni)
    end
    append_flips(ys, yi)
    append_flips(ls, ri)
  end
  return xs, ys, ls, ns
end
-- Build a small CUDA network: 2x2 average pooling with stride 2, followed
-- by 2x nearest-neighbour upsampling. Currently only referenced from the
-- disabled GPU path inside low_resolution.
local function lowres_model()
  local net = nn.Sequential()
  local pool = nn.SpatialAveragePooling(2, 2, 2, 2)
  local upsample = nn.SpatialUpSamplingNearest(2)
  net:add(pool)
  net:add(upsample)
  return net:cuda()
end
-- Cached model and benchmark result for the disabled GPU path in
-- low_resolution. They stay nil while that path is commented out.
local g_lowres_model, g_lowres_gpu

--- Degrade `src` with GraphicsMagick: a Box-filter resize to half width and
-- height, then a Box-filter resize back to the original size.
-- @param src RGB image tensor in "DHW" layout
-- @return byte tensor in "RGB"/"DHW" layout, the same size as src
function pairwise_transform_utils.low_resolution(src)
  --[[
  -- Disabled alternative: benchmark the CUDA lowres_model against
  -- GraphicsMagick once, then use whichever was faster.
  -- I am not sure that the following process is thread-safe.
  g_lowres_model = g_lowres_model or lowres_model()
  if g_lowres_gpu == nil then
    --benchmark
    local gpu_time = sys.clock()
    for i = 1, 10 do
      g_lowres_model:forward(src:cuda()):byte()
    end
    gpu_time = sys.clock() - gpu_time
    local cpu_time = sys.clock()
    for i = 1, 10 do
      gm.Image(src, "RGB", "DHW"):
        size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
        size(src:size(3), src:size(2), "Box"):
        toTensor("byte", "RGB", "DHW")
    end
    cpu_time = sys.clock() - cpu_time
    --print(gpu_time, cpu_time)
    if gpu_time < cpu_time then
      g_lowres_gpu = true
    else
      g_lowres_gpu = false
    end
  end
  if g_lowres_gpu then
    return g_lowres_model:forward(src:cuda()):byte()
  else
    return gm.Image(src, "RGB", "DHW"):
      size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
      size(src:size(3), src:size(2), "Box"):
      toTensor("byte", "RGB", "DHW")
  end
  --]]
  local image = gm.Image(src, "RGB", "DHW")
  local width, height = src:size(3), src:size(2)
  image = image:size(width * 0.5, height * 0.5, "Box")
  image = image:size(width, height, "Box")
  return image:toTensor("byte", "RGB", "DHW")
end

return pairwise_transform_utils
|