Sfoglia il codice sorgente

Add random blur method for data augmentation

nagadomi 8 anni fa
parent
commit
5a3d012f4e
5 ha cambiato i file con 94 aggiunte e 1 eliminazioni
  1. 53 0
      lib/data_augmentation.lua
  2. 21 1
      lib/iproc.lua
  3. 4 0
      lib/pairwise_transform_utils.lua
  4. 4 0
      lib/settings.lua
  5. 12 0
      train.lua

+ 53 - 0
lib/data_augmentation.lua

@@ -1,3 +1,5 @@
+require 'pl'
+require 'cunn'
 local iproc = require 'iproc'
 local gm = {}
 gm.Image = require 'graphicsmagick.Image'
@@ -69,6 +71,41 @@ function data_augmentation.unsharp_mask(src, p)
       return src
    end
 end
+data_augmentation.blur_conv = {}
+function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
+   size = size or "3"
+   filters = utils.split(size, ",")
+   for i = 1, #filters do
+      local s = tonumber(filters[i])
+      filters[i] = s
+      if not data_augmentation.blur_conv[s] then
+	 data_augmentation.blur_conv[s] = nn.SpatialConvolutionMM(1, 1, s, s, 1, 1, (s - 1) / 2, (s - 1) / 2):noBias():cuda()
+      end
+   end
+   if torch.uniform() < p then
+      local src, conversion = iproc.byte2float(src)
+      local kernel_size = filters[torch.random(1, #filters)]
+      local sigma
+      if sigma_min == sigma_max then
+	 sigma = sigma_min
+      else
+	 sigma = torch.uniform(sigma_min, sigma_max)
+      end
+      local kernel = iproc.gaussian2d(kernel_size, sigma)
+      data_augmentation.blur_conv[kernel_size].weight:copy(kernel)
+      local dest = torch.Tensor(3, src:size(2), src:size(3))
+      dest[1]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[1]:reshape(1, src:size(2), src:size(3)):cuda()))
+      dest[2]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[2]:reshape(1, src:size(2), src:size(3)):cuda()))
+      dest[3]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[3]:reshape(1, src:size(2), src:size(3)):cuda()))
+
+      if conversion then
+	 dest = iproc.float2byte(dest)
+      end
+      return dest
+   else
+      return src
+   end
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)
@@ -119,4 +156,20 @@ function data_augmentation.flip(src)
    end
    return dest
 end
+
+local function test_blur()
+   torch.setdefaulttensortype("torch.FloatTensor")
+   local image =require 'image'
+   local src = image.lena()
+
+   image.display({image = src, min=0, max=1})
+   local dest = data_augmentation.blur(src, 1.0, "3,5", 0.5, 0.6)
+   image.display({image = dest, min=0, max=1})
+   dest = data_augmentation.blur(src, 1.0, "3", 1.0, 1.0)
+   image.display({image = dest, min=0, max=1})
+   dest = data_augmentation.blur(src, 1.0, "5", 0.75, 0.75)
+   image.display({image = dest, min=0, max=1})
+end
+--test_blur()
+
 return data_augmentation

+ 21 - 1
lib/iproc.lua

@@ -254,7 +254,19 @@ function iproc.yuv2rgb(...)
    -- return RGB image
    return output
 end
-
+function iproc.gaussian2d(kernel_size, sigma)
+   sigma = sigma or 1
+   local kernel = torch.Tensor(kernel_size, kernel_size)
+   local u = math.floor(kernel_size / 2) + 1
+   local amp = (1 / math.sqrt(2 * math.pi * sigma^2))
+   for x = 1, kernel_size do
+      for y = 1, kernel_size do
+	 kernel[x][y] = amp * math.exp(-((x - u)^2 + (y - u)^2) / (2 * sigma^2))
+      end
+   end
+   kernel:div(kernel:sum())
+   return kernel
+end
 local function test_conversion()
    local a = torch.linspace(0, 255, 256):float():div(255.0)
    local b = iproc.float2byte(a)
@@ -286,9 +298,17 @@ local function test_flip()
    print((image.vflip(src) - iproc.vflip(src)):sum())
    print((image.vflip(src_byte) - iproc.vflip(src_byte)):sum())
 end
+local function test_gaussian2d()
+   local t = {3, 5, 7}
+   for i = 1, #t do
+      local kp = iproc.gaussian2d(t[i], 0.5)
+      print(kp)
+   end
+end
 
 --test_conversion()
 --test_flip()
+--test_gaussian2d()
 
 return iproc
 

+ 4 - 0
lib/pairwise_transform_utils.lua

@@ -56,6 +56,10 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
       dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
       dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
       dest = data_augmentation.flip(dest)
+      dest = data_augmentation.blur(dest, options.random_blur_rate,
+				    options.random_blur_size, 
+				    options.random_blur_min,
+				    options.random_blur_max)
       dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
       dest = data_augmentation.overlay(dest, options.random_overlay_rate)
       dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)

+ 4 - 0
lib/settings.lua

@@ -32,6 +32,10 @@ cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise
 cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
 cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
 cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
+cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0-1.0)')
+cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
+cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
+cmd:option("-random_blur_sigma_max", 0.75, 'max sigma for random gaussian blur')
 cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-crop_size", 48, 'crop size')

+ 12 - 0
train.lua

@@ -97,6 +97,10 @@ local function transform_pool_init(has_resize, offset)
 		     random_color_noise_rate = random_color_noise_rate,
 		     random_overlay_rate = random_overlay_rate,
 		     random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
+		     random_blur_rate = settings.random_blur_rate,
+		     random_blur_size = settings.random_blur_size,
+		     random_blur_sigma_min = settings.random_blur_sigma_min,
+		     random_blur_sigma_max = settings.random_blur_sigma_max,
 		     max_size = settings.max_size,
 		     active_cropping_rate = active_cropping_rate,
 		     active_cropping_tries = active_cropping_tries,
@@ -114,6 +118,10 @@ local function transform_pool_init(has_resize, offset)
 		     random_color_noise_rate = random_color_noise_rate,
 		     random_overlay_rate = random_overlay_rate,
 		     random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
+		     random_blur_rate = settings.random_blur_rate,
+		     random_blur_size = settings.random_blur_size,
+		     random_blur_sigma_min = settings.random_blur_sigma_min,
+		     random_blur_sigma_max = settings.random_blur_sigma_max,
 		     max_size = settings.max_size,
 		     jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
 		     active_cropping_rate = active_cropping_rate,
@@ -132,6 +140,10 @@ local function transform_pool_init(has_resize, offset)
 		     random_color_noise_rate = random_color_noise_rate,
 		     random_overlay_rate = random_overlay_rate,
 		     random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
+		     random_blur_rate = settings.random_blur_rate,
+		     random_blur_size = settings.random_blur_size,
+		     random_blur_sigma_min = settings.random_blur_sigma_min,
+		     random_blur_sigma_max = settings.random_blur_sigma_max,
 		     max_size = settings.max_size,
 		     jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
 		     nr_rate = settings.nr_rate,