Quellcode durchsuchen

Add data augmentation for user method

nagadomi vor 8 Jahren
Ursprung
Commit
b066761cdc
6 geänderte Dateien mit 163 neuen und 27 gelöschten Zeilen
  1. 43 0
      lib/data_augmentation.lua
  2. 42 2
      lib/iproc.lua
  3. 1 25
      lib/pairwise_transform_user.lua
  4. 51 0
      lib/pairwise_transform_utils.lua
  5. 10 0
      lib/settings.lua
  6. 16 0
      train.lua

+ 43 - 0
lib/data_augmentation.lua

@@ -96,6 +96,49 @@ function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
       return src
    end
 end
+-- Randomly rescales an (x, y) image pair by one shared factor.
+-- x, y: image tensors whose size(2) and size(3) must match (asserted).
+-- p: probability threshold compared against torch.uniform().
+-- scale_min, scale_max: bounds handed to torch.uniform() for the factor.
+-- Returns the (possibly resized) pair; when the random draw is not below p
+-- the inputs are returned untouched.
+function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
+   if not (torch.uniform() < p) then
+      return x, y
+   end
+   assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+   local factor = torch.uniform(scale_min, scale_max)
+   -- target size is truncated to whole pixels
+   local new_h = math.floor(x:size(2) * factor)
+   local new_w = math.floor(x:size(3) * factor)
+   local scaled_x = iproc.scale(x, new_w, new_h, "Triangle")
+   local scaled_y = iproc.scale(y, new_w, new_h, "Triangle")
+   return scaled_x, scaled_y
+end
+-- Randomly rotates an (x, y) image pair by one shared angle.
+-- x, y: image tensors whose size(2) and size(3) must match (asserted).
+-- p: probability threshold compared against torch.uniform().
+-- r_min, r_max: bounds handed to torch.uniform() for the drawn value.
+-- Returns the (possibly rotated) pair; when the random draw is not below p
+-- the inputs are returned untouched.
+-- NOTE(review): the drawn value is multiplied by pi / 360, whereas a
+-- degrees-to-radians conversion would be pi / 180. If r_min/r_max are meant
+-- as degrees, the applied rotation is half of the drawn value -- confirm
+-- whether that is intentional before changing it.
+function data_augmentation.pairwise_rotate(x, y, p, r_min, r_max)
+   if not (torch.uniform() < p) then
+      return x, y
+   end
+   assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+   local drawn = torch.uniform(r_min, r_max)
+   local theta = drawn / 360.0 * math.pi
+   local rotated_x = iproc.rotate(x, theta)
+   local rotated_y = iproc.rotate(y, theta)
+   return rotated_x, rotated_y
+end
+-- Randomly negates both images of an (x, y) pair.
+-- x, y: image tensors whose size(2) and size(3) must match (asserted).
+-- p: probability threshold compared against torch.uniform().
+-- Returns the pair with both images passed through iproc.negate when the
+-- random draw is below p; otherwise the inputs are returned untouched.
+function data_augmentation.pairwise_negate(x, y, p)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      -- Both sides must be negated together. The previous code called
+      -- iproc.rotate(y, r) with an undefined `r`, so y was never negated.
+      x = iproc.negate(x)
+      y = iproc.negate(y)
+      return x, y
+   else
+      return x, y
+   end
+end
+-- Randomly negates only the x image of an (x, y) pair.
+-- x, y: image tensors whose size(2) and size(3) must match (asserted).
+-- p: probability threshold compared against torch.uniform().
+-- Returns (negated x, y) when the random draw is below p; otherwise the
+-- inputs are returned untouched. y is never modified.
+function data_augmentation.pairwise_negate_x(x, y, p)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      -- iproc.negate takes a single argument; the stray undefined global `r`
+      -- that used to be passed here has been dropped.
+      x = iproc.negate(x)
+      return x, y
+   else
+      return x, y
+   end
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)

+ 42 - 2
lib/iproc.lua

@@ -1,8 +1,7 @@
 local gm = {}
 gm.Image = require 'graphicsmagick.Image'
-local image = nil
 require 'dok'
-require 'image'
+local image = require 'image'
 local iproc = {}
 local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
 
@@ -158,6 +157,47 @@ function iproc.vflip(src)
    local im = gm.Image(src, color, "DHW")
    return im:flip():toTensor(t, color, "DHW")
 end
+-- Rotates `src` into `dst` by building a per-pixel flow field and handing it
+-- to image.warp.
+-- src:   2D tensor (size(1) x size(2) read as height x width) or 3D tensor
+--        (size(2) x size(3) read as height x width); anything else raises
+--        through dok.error.
+-- dst:   destination tensor; resized to match src before warping.
+-- theta: angle fed to math.cos / math.sin, i.e. radians.
+-- mode:  interpolation mode string forwarded unchanged to image.warp.
+-- Returns whatever image.warp returns.
+local function rotate_with_warp(src, dst, theta, mode)
+  local height
+  local width
+  if src:dim() == 2 then
+    height = src:size(1)
+    width = src:size(2)
+  elseif src:dim() == 3 then
+    height = src:size(2)
+    width = src:size(3)
+  else
+    dok.error('src image must be 2D or 3D', 'image.rotate')
+  end
+  local flow = torch.Tensor(2, height, width)
+  -- 2x2 rotation matrix for the angle -theta
+  local kernel = torch.Tensor({{math.cos(-theta), -math.sin(-theta)},
+			       {math.sin(-theta), math.cos(-theta)}})
+  -- flow[1] varies along rows only: floor(height / 2 + 0.5) on the first row,
+  -- decreasing linearly by (height - 1) in total down to the last row
+  flow[1] = torch.ger(torch.linspace(0, 1, height), torch.ones(width))
+  flow[1]:mul(-(height -1)):add(math.floor(height / 2 + 0.5))
+  -- flow[2] is the same construction along columns, using width
+  flow[2] = torch.ger(torch.ones(height), torch.linspace(0, 1, width))
+  flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
+  -- flow <- flow - kernel * flow, with the grid flattened to 2 x (height * width)
+  flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
+  dst:resizeAs(src)
+  -- NOTE(review): the trailing `true, 'pad'` are presumably image.warp's
+  -- offset-mode flag and clamp mode -- confirm against the image package docs
+  return image.warp(dst, src, flow, mode, true, 'pad')
+end
+-- Rotates an image by theta using bicubic warping.
+-- src:   image tensor; it is first passed through iproc.byte2float, which
+--        also reports whether a conversion took place.
+-- theta: rotation angle forwarded to rotate_with_warp.
+-- Returns a new tensor clamped to [0, 1]; if the input was converted on the
+-- way in, the result is converted back with iproc.float2byte.
+function iproc.rotate(src, theta)
+   local float_src, was_converted = iproc.byte2float(src)
+   local rotated = torch.Tensor():typeAs(float_src):resizeAs(float_src)
+   rotate_with_warp(float_src, rotated, theta, 'bicubic')
+   -- the warp result is clamped back into [0, 1]
+   rotated:clamp(0, 1)
+   if not was_converted then
+      return rotated
+   end
+   return iproc.float2byte(rotated)
+end
+-- Returns the negative of an image as a new tensor.
+-- Computes `-src + 255` when src is a torch.ByteTensor and `-src + 1` for
+-- every other tensor type.
+function iproc.negate(src)
+   local offset = 1
+   if src:type() == "torch.ByteTensor" then
+      offset = 255
+   end
+   return -src + offset
+end
 
 function iproc.gaussian2d(kernel_size, sigma)
    sigma = sigma or 1

+ 1 - 25
lib/pairwise_transform_user.lua

@@ -4,37 +4,13 @@ local gm = {}
 gm.Image = require 'graphicsmagick.Image'
 local pairwise_transform = {}
 
-local function crop_if_large(x, y, scale_y, max_size, mod)
-   local tries = 4
-   if y:size(2) > max_size and y:size(3) > max_size then
-      assert(max_size % 4 == 0)
-      local rect_x, rect_y
-      for i = 1, tries do
-	 local yi = torch.random(0, y:size(2) - max_size)
-	 local xi = torch.random(0, y:size(3) - max_size)
-	 if mod then
-	    yi = yi - (yi % mod)
-	    xi = xi - (xi % mod)
-	 end
-	 rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
-	 rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y)
-	 -- ignore simple background
-	 if rect_y:float():std() >= 0 then
-	    break
-	 end
-      end
-      return rect_x, rect_y
-   else
-      return x, y
-   end
-end
 function pairwise_transform.user(x, y, size, offset, n, options)
    assert(x:size(1) == y:size(1))
 
    local scale_y = y:size(2) / x:size(2)
    assert(x:size(3) == y:size(3) / scale_y)
 
-   x, y = crop_if_large(x, y, scale_y, options.max_size, scale_y)
+   x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
    assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
    local batch = {}
    local lowres_y = pairwise_utils.low_resolution(y)

+ 51 - 0
lib/pairwise_transform_utils.lua

@@ -36,6 +36,30 @@ function pairwise_transform_utils.crop_if_large(src, max_size, mod)
       return src
    end
 end
+-- Crops a random, aligned region out of an (x, y) pair when y is large.
+-- x, y:     image tensors; y is scale_y times the size of x.
+-- scale_y:  ratio used to map y coordinates onto x (y coords / scale_y).
+-- max_size: side length of the square cut from y; must be a multiple of 4.
+-- mod:      optional; when given, the crop origin is rounded down to a
+--           multiple of it.
+-- Returns the cropped pair, or the inputs untouched unless both size(2) and
+-- size(3) of y exceed max_size.
+-- NOTE(review): `std() >= 0` holds for every non-NaN value, so the loop
+-- normally stops after its first attempt -- confirm the intended threshold.
+function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
+   if not (y:size(2) > max_size and y:size(3) > max_size) then
+      return x, y
+   end
+   assert(max_size % 4 == 0)
+   local max_tries = 4
+   local crop_x, crop_y
+   for _ = 1, max_tries do
+      local top = torch.random(0, y:size(2) - max_size)
+      local left = torch.random(0, y:size(3) - max_size)
+      if mod then
+         top = top - (top % mod)
+         left = left - (left % mod)
+      end
+      -- the same window expressed in x's coordinate system
+      local x_left = left / scale_y
+      local x_top = top / scale_y
+      local x_span = max_size / scale_y
+      crop_y = iproc.crop(y, left, top, left + max_size, top + max_size)
+      crop_x = iproc.crop(x, x_left, x_top, x_left + x_span, x_top + x_span)
+      -- ignore simple background
+      if crop_y:float():std() >= 0 then
+         break
+      end
+   end
+   return crop_x, crop_y
+end
 function pairwise_transform_utils.preprocess(src, crop_size, options)
    local dest = src
    local box_only = false
@@ -65,6 +89,33 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
    end
    return dest
 end
+-- Applies the pairwise preprocessing / augmentation chain for the user method.
+-- x, y:    input / target image tensors.
+-- scale_y: ratio between y and x, forwarded to crop_if_large_pair (also as
+--          its `mod` argument).
+-- size:    crop size; used to keep the random scale from shrinking x's
+--          shorter side to `size` or below.
+-- options: table providing max_size, the random_pairwise_* settings and
+--          pairwise_y_binary.
+-- Order: crop-if-large, rotate, scale, negate, negate-x, crop_mod4, then
+-- optional binarization of y. Returns the processed pair.
+function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
+   x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
+
+   x, y = data_augmentation.pairwise_rotate(
+      x, y,
+      options.random_pairwise_rotate_rate,
+      options.random_pairwise_rotate_min,
+      options.random_pairwise_rotate_max)
+
+   -- raise the lower bound so the scaled shorter side stays above `size`
+   local shortest_side = math.min(x:size(2), x:size(3))
+   local lower = math.max(options.random_pairwise_scale_min, size / (1 + shortest_side))
+   local upper = math.max(lower, options.random_pairwise_scale_max)
+   x, y = data_augmentation.pairwise_scale(
+      x, y,
+      options.random_pairwise_scale_rate,
+      lower,
+      upper)
+
+   x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
+   x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
+
+   x = iproc.crop_mod4(x)
+   y = iproc.crop_mod4(y)
+
+   if options.pairwise_y_binary then
+      -- in-place threshold: values below 128 become 0, all remaining
+      -- positive values become 255
+      -- NOTE(review): presumably expects y in the 0..255 byte range -- confirm
+      y[torch.lt(y, 128)] = 0
+      y[torch.gt(y, 0)] = 255
+   end
+
+   return x, y
+end
 function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
    assert("crop_size % scale == 0", size % scale == 0)

+ 10 - 0
lib/settings.lua

@@ -37,6 +37,15 @@ cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0
 cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
 cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
 cmd:option("-random_blur_sigma_max", 0.75, 'max sigma for random gaussian blur')
+-- pairwise (x and y together) data augmentation options for the user method
+cmd:option("-random_pairwise_scale_rate", 0.0, 'data augmentation using pairwise resize for user method')
+cmd:option("-random_pairwise_scale_min", 0.85, 'min scale factor for random pairwise scale')
+cmd:option("-random_pairwise_scale_max", 1.176, 'max scale factor for random pairwise scale')
+cmd:option("-random_pairwise_rotate_rate", 0.0, 'data augmentation using pairwise rotate for user method')
+cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairwise rotate')
+cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
+cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using negate image for user method')
+cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using negate image only x side for user method')
+cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
 cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-crop_size", 48, 'crop size')
@@ -81,6 +90,7 @@ end
 to_bool(settings, "plot")
 to_bool(settings, "save_history")
 to_bool(settings, "use_transparent_png")
+to_bool(settings, "pairwise_y_binary")
 
 if settings.plot then
    require 'gnuplot'

+ 16 - 0
train.lua

@@ -179,6 +179,15 @@ local function transform_pool_init(has_resize, offset)
 		     max_size = settings.max_size,
 		     active_cropping_rate = active_cropping_rate,
 		     active_cropping_tries = active_cropping_tries,
+		     random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
+		     random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
+		     random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
+		     random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
+		     random_pairwise_scale_min = settings.random_pairwise_scale_min,
+		     random_pairwise_scale_max = settings.random_pairwise_scale_max,
+		     random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
+		     random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
+		     pairwise_y_binary = settings.pairwise_y_binary,
 		     rgb = (settings.color == "rgb")}, meta)
 	       return pairwise_transform.user(x, y,
 					      settings.crop_size, offset,
@@ -393,6 +402,13 @@ local function train()
    else
       model = srcnn.create(settings.model, settings.backend, settings.color)
    end
+   if model.w2nn_input_size then
+      if settings.crop_size ~= model.w2nn_input_size then
+	 io.stderr:write(string.format("warning: crop_size is replaced with %d\n",
+				       model.w2nn_input_size))
+	 settings.crop_size = model.w2nn_input_size
+      end
+   end
    dir.makepath(settings.model_dir)
 
    local offset = reconstruct.offset_size(model)