Forráskód Böngészése

merge random erasing

nagadomi 6 éve
szülő
commit
93cd40a53c
4 módosított fájl, 52 hozzáadás és 1 törlés
  1. 28 0
      lib/data_augmentation.lua
  2. 5 1
      lib/pairwise_transform_utils.lua
  3. 4 0
      lib/settings.lua
  4. 15 0
      train.lua

+ 28 - 0
lib/data_augmentation.lua

@@ -13,6 +13,34 @@ local function pcacov(x)
    local ce, cv = torch.symeig(c, 'V')
    return ce, cv
 end
+-- Random erasing augmentation.
+--
+-- With probability `p` (compared against a single torch.uniform() draw),
+-- returns a copy of `src` in which `n` randomly placed rectangles have been
+-- painted with a flat color; otherwise returns `src` itself, untouched.
+--
+-- src      : image tensor indexed as (channel, row, column).
+-- p        : erasing probability.
+-- n        : number of rectangles to erase.
+-- rect_min : lower bound used when drawing the rectangle height/width offset.
+-- rect_max : upper bound used when drawing the rectangle height/width offset.
+--
+-- Each rectangle is filled, per channel, with the value of its own top-left
+-- pixel. The result is converted back with iproc.float2byte() when
+-- iproc.byte2float() reports that it converted the input
+-- (presumably because `src` was a byte image -- confirm against iproc).
+--
+-- NOTE(review): rect_x2/rect_y2 are passed to Tensor:sub() as end indices,
+-- which appear to be inclusive, so a rectangle covers rect_w + 1 columns and
+-- rect_h + 1 rows before clamping to the image border -- confirm intended.
+function data_augmentation.erase(src, p, n, rect_min, rect_max)
+   if not (torch.uniform() < p) then
+      return src
+   end
+   local img, conversion = iproc.byte2float(src)
+   -- Work on a private copy so the caller's tensor is never modified.
+   img = img:contiguous():clone()
+   local ch = img:size(1)
+   local height = img:size(2)
+   local width = img:size(3)
+   -- Per-component upper bounds for (y offset, x offset, extra height, extra width).
+   -- Loop-invariant, so it is built once instead of once per rectangle.
+   local rect_range = rect_max - rect_min
+   local scale = torch.Tensor({height - 1, width - 1, rect_range, rect_range})
+   for _ = 1, n do
+      -- Four uniform draws scaled to their ranges, truncated to integers.
+      local r = torch.Tensor(4):uniform():cmul(scale):int()
+      local rect_y1 = r[1] + 1
+      local rect_x1 = r[2] + 1
+      local rect_h = r[3] + rect_min
+      local rect_w = r[4] + rect_min
+      -- Clamp the far corner to the image border.
+      local rect_x2 = math.min(rect_x1 + rect_w, width)
+      local rect_y2 = math.min(rect_y1 + rect_h, height)
+      local sub_rect = img:sub(1, ch, rect_y1, rect_y2, rect_x1, rect_x2)
+      for c = 1, ch do
+         -- Fill value is read before the fill overwrites the region.
+         sub_rect[c]:fill(img[c][rect_y1][rect_x1])
+      end
+   end
+   if conversion then
+      img = iproc.float2byte(img)
+   end
+   return img
+end
 function data_augmentation.color_noise(src, p, factor)
    factor = factor or 0.1
    if torch.uniform() < p then

+ 5 - 1
lib/pairwise_transform_utils.lua

@@ -105,7 +105,11 @@ function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
 					   scale_max)
    x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
    x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
-
+   x = data_augmentation.erase(x, 
+			       options.random_erasing_rate,
+			       options.random_erasing_n,
+			       options.random_erasing_rect_min,
+			       options.random_erasing_rect_max)
    x = iproc.crop_mod4(x)
    y = iproc.crop_mod4(y)
    return x, y

+ 4 - 0
lib/settings.lua

@@ -44,6 +44,10 @@ cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairw
 cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
 cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
 cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
+-- Random erasing augmentation options: apply rate, rectangle count, and rectangle size bounds.
+cmd:option("-random_erasing_rate", 0.0, 'data augmentation using random erasing for user method')
+cmd:option("-random_erasing_n", 1, 'number of erasing')
+cmd:option("-random_erasing_rect_min", 8, 'rect min size')
+cmd:option("-random_erasing_rect_max", 32, 'rect max size')
 cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
 cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
 cmd:option("-scale", 2.0, 'scale factor (2)')

+ 15 - 0
train.lua

@@ -215,6 +215,17 @@ local function transform_pool_init(has_resize, offset)
 						    settings.crop_size, offset,
 						    n, conf)
 	    elseif settings.method == "user" then
+	       local random_erasing_rate = 0
+	       local random_erasing_n = 0
+	       local random_erasing_rect_min = 0
+	       local random_erasing_rect_max = 0
+	       if is_validation then
+	       else
+		  random_erasing_rate = settings.random_erasing_rate
+		  random_erasing_n = settings.random_erasing_n
+		  random_erasing_rect_min = settings.random_erasing_rect_min
+		  random_erasing_rect_max = settings.random_erasing_rect_max
+	       end
 	       local conf = tablex.update({
 		     gcn = settings.gcn,
 		     max_size = settings.max_size,
@@ -230,6 +241,10 @@ local function transform_pool_init(has_resize, offset)
 		     random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
 		     pairwise_y_binary = settings.pairwise_y_binary,
 		     pairwise_flip = settings.pairwise_flip,
+		     random_erasing_rate = random_erasing_rate,
+		     random_erasing_n = random_erasing_n,
+		     random_erasing_rect_min = random_erasing_rect_min,
+		     random_erasing_rect_max = random_erasing_rect_max,
 		     rgb = (settings.color == "rgb")}, meta)
 	       return pairwise_transform.user(x, y,
 					      settings.crop_size, offset,