require 'image'
local iproc = require 'iproc'
local data_augmentation = require 'data_augmentation'

--- Helpers shared by the pairwise (input, target) training-pair generators.
-- Image tensors are indexed channel-first: throughout this module dim 2 is
-- treated as the vertical (y) axis and dim 3 as the horizontal (x) axis.
local pairwise_transform_utils = {}

--- With probability `p`, downscale `src` to half its width and height.
-- The resampling filter is picked uniformly at random from `filters`.
-- @param src image tensor
-- @param p probability in [0, 1] of applying the downscale
-- @param filters non-empty list of filter names accepted by iproc.scale
-- @return the halved image, or `src` itself (not a copy) when skipped
function pairwise_transform_utils.random_half(src, p, filters)
   if torch.uniform() < p then
      local filter = filters[torch.random(1, #filters)]
      return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
   end
   return src
end

--- Randomly crop a `max_size` x `max_size` square out of `src` when BOTH
-- spatial dimensions exceed `max_size`; otherwise return `src` unchanged.
-- @param src image tensor
-- @param max_size side of the square crop; must be a multiple of 4
-- @param mod optional; when given, the crop origin is rounded down to a
--   multiple of `mod` (preprocess passes 2 for "Box"-only samples)
-- @return the cropped tensor, or `src` itself when it is not large enough
function pairwise_transform_utils.crop_if_large(src, max_size, mod)
   local tries = 4
   if src:size(2) > max_size and src:size(3) > max_size then
      assert(max_size % 4 == 0)
      local rect
      for _ = 1, tries do
         local yi = torch.random(0, src:size(2) - max_size)
         local xi = torch.random(0, src:size(3) - max_size)
         if mod then
            yi = yi - (yi % mod)
            xi = xi - (xi % mod)
         end
         rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
         -- Meant to "ignore simple background" by retrying flat crops.
         -- NOTE(review): std() is never negative, so this test always passes
         -- and the loop always stops after the first try, which makes the
         -- retry effectively disabled. Use a threshold above 0 to re-enable it.
         if rect:float():std() >= 0 then
            break
         end
      end
      return rect
   end
   return src
end

-- Random augmentations shared by both preprocess branches. The call order
-- is significant because each step may draw from the RNG.
local function apply_shared_augmentations(dest, options)
   dest = data_augmentation.flip(dest)
   dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
   dest = data_augmentation.overlay(dest, options.random_overlay_rate)
   dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
   return dest
end

--- Prepare a source image before training pairs are cut from it.
-- Images larger than max(2 * crop_size, options.max_size) in both
-- dimensions are randomly cropped to that square, then augmented.
-- If the sample's only allowed filter is "Box": no random halving is done,
-- crop origins are kept even, and the result goes through iproc.crop_mod4
-- (presumably trimming to a multiple of 4, so TODO confirm). Otherwise the
-- image may first be randomly halved and is finished with
-- data_augmentation.shift_1px.
-- @param src image tensor
-- @param crop_size patch size used by later cropping
-- @param options table; reads data.filters (optional), max_size,
--   random_half_rate, downsampling_filters and the random_*_rate settings
-- @return the preprocessed image tensor
function pairwise_transform_utils.preprocess(src, crop_size, options)
   local filters = options.data and options.data.filters
   local box_only = false
   if filters and #filters == 1 and filters[1] == "Box" then
      box_only = true
   end
   local max_size = math.max(crop_size * 2, options.max_size)
   local dest = src
   if box_only then
      -- mod = 2 keeps the crop origin even (pos % 2 == 0)
      dest = pairwise_transform_utils.crop_if_large(dest, max_size, 2)
      dest = apply_shared_augmentations(dest, options)
      dest = iproc.crop_mod4(dest)
   else
      dest = pairwise_transform_utils.random_half(dest, options.random_half_rate,
                                                  options.downsampling_filters)
      dest = pairwise_transform_utils.crop_if_large(dest, max_size)
      dest = apply_shared_augmentations(dest, options)
      dest = data_augmentation.shift_1px(dest)
   end
   return dest
end

--- Cut one aligned (input, target) patch pair from input `x` and target `y`.
-- With probability `p`, "active cropping" is used. In that mode `tries`
-- candidate positions are sampled, and the one where `y` differs most from
-- `lowres_y` (largest sum of squared differences) is kept. Otherwise a
-- single random position is used.
-- @param x input tensor; its spatial size times `scale` must equal y's
-- @param y target tensor
-- @param lowres_y tensor compared against `y` at the same coordinates
--   (assumed to have y's size, e.g. a degraded copy of y; TODO confirm)
-- @param size patch side in y coordinates; must be divisible by `scale`
-- @param scale resolution ratio between y and x
-- @param p probability of using active cropping
-- @param tries number of candidate positions for active cropping (>= 1)
-- @return xc, a (size/scale)-square crop of x, and yc, a size-square crop of y
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
   -- Condition first, message second: the old assert(msg, cond) form never
   -- failed, because a string is always truthy.
   assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
          "x:size == y:size")
   assert(size % scale == 0, "crop_size % scale == 0")
   local x_size = size / scale

   -- Sample a crop origin in y coordinates. It is a multiple of `scale`, so
   -- dividing by `scale` gives the matching origin in x.
   local function random_origin()
      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
      return xi, yi
   end

   -- Copying crops of y and x at the same (y-coordinate) origin.
   local function crop_pair(xi, yi)
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x, xi / scale, yi / scale,
                            xi / scale + x_size, yi / scale + x_size)
      return xc, yc
   end

   if p < torch.uniform() then
      return crop_pair(random_origin())
   end

   local best_se, best_xi, best_yi = 0.0, nil, nil
   -- Reused scratch buffer; integer type so squared differences stay exact.
   local diff = torch.LongTensor(y:size(1), size, size)
   for _ = 1, tries do
      local xi, yi = random_origin()
      local yc = iproc.crop_nocopy(y, xi, yi, xi + size, yi + size)
      local lc = iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size)
      diff:copy(yc:long()):csub(lc:long())
      diff:cmul(diff)
      local se = diff:sum()
      if se >= best_se then
         best_xi, best_yi, best_se = xi, yi, se
      end
   end
   return crop_pair(best_xi, best_yi)
end

-- Return {t, vflip(t), hflip(t), hflip(vflip(t))}.
local function flip_variants(t)
   local v = image.vflip(t)
   return {t, v, image.hflip(t), image.hflip(v)}
end

-- Contiguous copy of `t` with its two spatial dimensions swapped.
local function transposed(t)
   return t:transpose(2, 3):contiguous()
end

-- Append the four flip variants of `t` to list `dst`.
local function append_flip_variants(dst, t)
   for _, v in ipairs(flip_variants(t)) do
      dst[#dst + 1] = v
   end
end

--- Expand one sample into its 8 flip/transpose variants. These are the
-- identity, vflip, hflip and both flips, each applied to the original and
-- to its transpose, which covers every rotation/reflection by multiples of 90°.
-- @param x input tensor
-- @param y target tensor
-- @param lowres_y tensor transformed alongside y
-- @param x_noise optional extra input tensor; may be nil
-- @return xs, ys, ls, ns: parallel lists of 8 tensors each, in the same
--   variant order; `ns` is empty when `x_noise` is nil. The first entries
--   are the original tensors themselves (not copies).
function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
   local xs, ys, ls, ns = {}, {}, {}, {}
   for pass = 1, 2 do
      local xi, ni, yi, ri = x, x_noise, y, lowres_y
      if pass == 2 then
         xi, yi, ri = transposed(x), transposed(y), transposed(lowres_y)
         if x_noise then
            ni = transposed(x_noise)
         end
      end
      append_flip_variants(xs, xi)
      if ni then
         append_flip_variants(ns, ni)
      end
      append_flip_variants(ys, yi)
      append_flip_variants(ls, ri)
   end
   return xs, ys, ls, ns
end

return pairwise_transform_utils