Pārlūkot izejas kodu

Update random erasing; Support x padding

nagadomi 6 gadi atpakaļ
vecāks
revīzija
0fe21eef70
4 mainītis faili ar 42 papildinājumiem un 22 dzēšanām
  1. 14 6
      convert_data.lua
  2. 21 11
      lib/data_augmentation.lua
  3. 5 5
      lib/pairwise_transform_utils.lua
  4. 2 0
      lib/settings.lua

+ 14 - 6
convert_data.lua

@@ -63,16 +63,24 @@ local function crop_if_large_pair(x, y, max_size)
       return x, y
       return x, y
    end
    end
 end
 end
-local function padding_x(x, pad)
+local function padding_x(x, pad, x_zero)
    if pad > 0 then
    if pad > 0 then
-      x = iproc.padding(x, pad, pad, pad, pad)
+      if x_zero then
+	 x = iproc.zero_padding(x, pad, pad, pad, pad)
+      else
+	 x = iproc.padding(x, pad, pad, pad, pad)
+      end
    end
    end
    return x
    return x
 end
 end
-local function padding_xy(x, y, pad, y_zero)
+local function padding_xy(x, y, pad, x_zero, y_zero)
    local scale = y:size(2) / x:size(2)
    local scale = y:size(2) / x:size(2)
    if pad > 0 then
    if pad > 0 then
-      x = iproc.padding(x, pad, pad, pad, pad)
+      if x_zero then
+	 x = iproc.zero_padding(x, pad, pad, pad, pad)
+      else
+	 x = iproc.padding(x, pad, pad, pad, pad)
+      end
       if y_zero then
       if y_zero then
 	 y = iproc.zero_padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
 	 y = iproc.zero_padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
       else
       else
@@ -127,7 +135,7 @@ local function load_images(list)
 		     xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
 		     xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
 		  end
 		  end
 		  xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
 		  xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
-		  xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
+		  xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_x_zero, settings.padding_y_zero)
 		  if settings.grayscale then
 		  if settings.grayscale then
 		     xx = iproc.rgb2y(xx)
 		     xx = iproc.rgb2y(xx)
 		     yy = iproc.rgb2y(yy)
 		     yy = iproc.rgb2y(yy)
@@ -140,7 +148,7 @@ local function load_images(list)
 	    else
 	    else
 	       im = crop_if_large(im, settings.max_training_image_size)
 	       im = crop_if_large(im, settings.max_training_image_size)
 	       im = iproc.crop_mod4(im)
 	       im = iproc.crop_mod4(im)
-	       im = padding_x(im, settings.padding)
+	       im = padding_x(im, settings.padding, settings.padding_x_zero)
 	       local scale = 1.0
 	       local scale = 1.0
 	       if settings.random_half_rate > 0.0 then
 	       if settings.random_half_rate > 0.0 then
 		  scale = 2.0
 		  scale = 2.0

+ 21 - 11
lib/data_augmentation.lua

@@ -13,6 +13,21 @@ local function pcacov(x)
    local ce, cv = torch.symeig(c, 'V')
    local ce, cv = torch.symeig(c, 'V')
    return ce, cv
    return ce, cv
 end
 end
+
+function random_rect_size(rect_min, rect_max)
+   local r = torch.Tensor(2):uniform():cmul(torch.Tensor({rect_max - rect_min, rect_max - rect_min})):int()
+   local rect_h = r[1] + rect_min
+   local rect_w = r[2] + rect_min
+   return rect_h, rect_w
+end
+function random_rect(height, width, rect_h, rect_w)
+   local r = torch.Tensor(2):uniform():cmul(torch.Tensor({height - 1 - rect_h, width-1 - rect_w})):int()
+   local rect_y1 = r[1] + 1
+   local rect_x1 = r[2] + 1
+   local rect_x2 = rect_x1 + rect_w
+   local rect_y2 = rect_y1 + rect_h
+   return {x1 = rect_x1, y1 = rect_y1, x2 = rect_x2, y2 = rect_y2}
+end
 function data_augmentation.erase(src, p, n, rect_min, rect_max)
 function data_augmentation.erase(src, p, n, rect_min, rect_max)
    if torch.uniform() < p then
    if torch.uniform() < p then
       local src, conversion = iproc.byte2float(src)
       local src, conversion = iproc.byte2float(src)
@@ -21,17 +36,12 @@ function data_augmentation.erase(src, p, n, rect_min, rect_max)
       local height = src:size(2)
       local height = src:size(2)
       local width = src:size(3)
       local width = src:size(3)
       for i = 1, n do
       for i = 1, n do
-	 local r = torch.Tensor(4):uniform():cmul(torch.Tensor({height-1, width-1, rect_max - rect_min, rect_max - rect_min})):int()
-	 local rect_y1 = r[1] + 1
-	 local rect_x1 = r[2] + 1
-	 local rect_h = r[3] + rect_min
-	 local rect_w = r[4] + rect_min
-	 local rect_x2 = math.min(rect_x1 + rect_w, width)
-	 local rect_y2 = math.min(rect_y1 + rect_h, height)
-	 local sub_rect = src:sub(1, ch, rect_y1, rect_y2, rect_x1, rect_x2)
-	 for i = 1, ch do
-	    sub_rect[i]:fill(src[i][rect_y1][rect_x1])
-	 end
+	 local rect_h, rect_w = random_rect_size(rect_min, rect_max)
+	 local rect1 = random_rect(height, width, rect_h, rect_w)
+	 local rect2 = random_rect(height, width, rect_h, rect_w)
+	 dest_rect = src:sub(1, ch, rect1.y1, rect1.y2, rect1.x1, rect1.x2)
+	 src_rect = src:sub(1, ch, rect2.y1, rect2.y2, rect2.x1, rect2.x2)
+	 dest_rect:copy(src_rect:clone())
       end
       end
       if conversion then
       if conversion then
 	 src = iproc.float2byte(src)
 	 src = iproc.float2byte(src)

+ 5 - 5
lib/pairwise_transform_utils.lua

@@ -92,6 +92,11 @@ end
 function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
 function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
 
 
    x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
    x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
+   x = data_augmentation.erase(x, 
+			       options.random_erasing_rate,
+			       options.random_erasing_n,
+			       options.random_erasing_rect_min,
+			       options.random_erasing_rect_max)
    x, y = data_augmentation.pairwise_rotate(x, y,
    x, y = data_augmentation.pairwise_rotate(x, y,
 					    options.random_pairwise_rotate_rate,
 					    options.random_pairwise_rotate_rate,
 					    options.random_pairwise_rotate_min,
 					    options.random_pairwise_rotate_min,
@@ -105,11 +110,6 @@ function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
 					   scale_max)
 					   scale_max)
    x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
    x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
    x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
    x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
-   x = data_augmentation.erase(x, 
-			       options.random_erasing_rate,
-			       options.random_erasing_n,
-			       options.random_erasing_rect_min,
-			       options.random_erasing_rect_max)
    x = iproc.crop_mod4(x)
    x = iproc.crop_mod4(x)
    y = iproc.crop_mod4(y)
    y = iproc.crop_mod4(y)
    return x, y
    return x, y

+ 2 - 0
lib/settings.lua

@@ -82,6 +82,7 @@ cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
 cmd:option("-update_criterion", "mse", 'mse|loss')
 cmd:option("-update_criterion", "mse", 'mse|loss')
 cmd:option("-padding", 0, 'replication padding size')
 cmd:option("-padding", 0, 'replication padding size')
 cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
 cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
+cmd:option("-padding_x_zero", 0, 'zero padding x for segmentation (0|1)')
 cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
 cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
 cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
 cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
 cmd:option("-invert_x", 0, 'invert x image in convert_lua')
 cmd:option("-invert_x", 0, 'invert x image in convert_lua')
@@ -104,6 +105,7 @@ to_bool(settings, "use_transparent_png")
 to_bool(settings, "pairwise_y_binary")
 to_bool(settings, "pairwise_y_binary")
 to_bool(settings, "pairwise_flip")
 to_bool(settings, "pairwise_flip")
 to_bool(settings, "padding_y_zero")
 to_bool(settings, "padding_y_zero")
+to_bool(settings, "padding_x_zero")
 to_bool(settings, "grayscale")
 to_bool(settings, "grayscale")
 to_bool(settings, "validation_filename_split")
 to_bool(settings, "validation_filename_split")
 to_bool(settings, "invert_x")
 to_bool(settings, "invert_x")