Kaynağa Gözat

add data augmentation method that uses overlay

nagadomi 10 yıl önce
ebeveyn
işleme
54580ba8c0
3 değiştirilmiş dosya ile 58 ekleme ve 8 silme
  1. 47 8
      lib/pairwise_transform.lua
  2. 7 0
      lib/settings.lua
  3. 4 0
      train.lua

+ 47 - 8
lib/pairwise_transform.lua

@@ -65,9 +65,28 @@ local function flip_augment(x, y)
       return x
    end
 end
+local function overlay_augment(src, p)
+   p = p or 0.25
+   if torch.uniform() > (1.0 - p) then
+      local r = torch.uniform(0.2, 0.8)
+      local t = "float"
+      if src:type() == "torch.ByteTensor" then
+	 src = src:float():div(255)
+	 t = "byte"
+      end
+      local flip = flip_augment(src)
+      flip:mul(r):add(src * (1.0 - r))
+      if t == "byte" then
+	 flip = flip:mul(255):byte()
+      end
+      return flip
+   else
+      return src
+   end
+end
 local INTERPOLATION_PADDING = 16
 function pairwise_transform.scale(src, scale, size, offset, options)
-   options = options or {color_noise = false, random_half = true, rgb = true}
+   options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
    if options.random_half then
       src = random_half(src)
    end
@@ -92,6 +111,9 @@ function pairwise_transform.scale(src, scale, size, offset, options)
    if options.color_noise then
       y = color_noise(y)
    end
+   if options.overlay then
+      y = overlay_augment(y)
+   end
    local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
    x = iproc.scale(x, y:size(3), y:size(2))
    y = y:float():div(255)
@@ -109,7 +131,7 @@ function pairwise_transform.scale(src, scale, size, offset, options)
    return x, y
 end
 function pairwise_transform.jpeg_(src, quality, size, offset, options)
-   options = options or {color_noise = false, random_half = true, rgb = true}
+   options = options or {color_noise = false, overlay = false, random_half = true, rgb = true}
    if options.random_half then
       src = random_half(src)
    end
@@ -121,6 +143,9 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
    if options.color_noise then
       y = color_noise(y)
    end
+   if options.overlay then
+      y = overlay_augment(y)
+   end
    x = y
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
@@ -236,6 +261,10 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    if options.color_noise then
       y = color_noise(y)
    end
+   if options.overlay then
+      y = overlay_augment(y)
+   end
+   
    x = y
    x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
    for i = 1, #quality do
@@ -325,12 +354,12 @@ end
 local function test_jpeg()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
-   local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, false)
+   local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, {})
    image.display({image = y, legend = "y:0"})
    image.display({image = x, legend = "x:0"})
    for i = 2, 9 do
-      local y, x = pairwise_transform.jpeg_(pairwise_transform.random_half(src),
-					    {i * 10}, 128, 0, {color_noise = false, random_half = true})
+      local y, x = pairwise_transform.jpeg_(random_half(src),
+					    {i * 10}, 128, 0, {color_noise = false, random_half = true, overlay = true, rgb = true})
       image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
       image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
       --print(x:mean(), y:mean())
@@ -342,7 +371,7 @@ local function test_scale()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
    for i = 1, 9 do
-      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_noise = true, random_half = true, rgb = true})
+      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_noise = true, random_half = true, rgb = true, overlay = true})
       image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
       image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
       print(y:size(), x:size())
@@ -354,14 +383,14 @@ local function test_jpeg_scale()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")   
    for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_noise = true, random_half = true})
+      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_noise = true, random_half = true, overlay = true})
       image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
       image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
       print(y:size(), x:size())
       --print(x:mean(), y:mean())
    end
    for i = 1, 9 do
-      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_noise = true, random_half = true})
+      local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_noise = true, random_half = true, overlay = true})
       image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
       image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
       print(y:size(), x:size())
@@ -376,9 +405,19 @@ local function test_color_noise()
       image.display(color_noise(src))
    end
 end
+local function test_overlay()
+   torch.setdefaulttensortype('torch.FloatTensor')
+   local loader = require './image_loader'
+   local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
+   for i = 1, 10 do
+      image.display(overlay_augment(src, 1.0))
+   end
+end
+
 --test_scale()
 --test_jpeg()
 --test_jpeg_scale()
 --test_color_noise()
+--test_overlay()
 
 return pairwise_transform

+ 7 - 0
lib/settings.lua

@@ -25,6 +25,7 @@ cmd:option("-noise_level", 1, '(1|2)')
 cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-color_noise", 0, 'enable data augmentation using color noise (1|0)')
+cmd:option("-overlay", 0, 'enable data augmentation using overlay (1|0)')
 cmd:option("-scale", 2.0, 'scale')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
@@ -69,6 +70,12 @@ if settings.color_noise == 1 then
 else
    settings.color_noise = false
 end
+if settings.overlay == 1 then
+   settings.overlay = true
+else
+   settings.overlay = false
+end
+
 torch.setnumthreads(settings.core)
 
 settings.images = string.format("%s/images.t7", settings.data_dir)

+ 4 - 0
train.lua

@@ -80,11 +80,13 @@ local function train()
    local transformer = function(x, is_validation)
       if is_validation == nil then is_validation = false end
       local color_noise = (not is_validation) and settings.color_noise
+      local overlay = (not is_validation) and settings.overlay
       if settings.method == "scale" then
 	 return pairwise_transform.scale(x,
 					 settings.scale,
 					 settings.crop_size, offset,
 					 { color_noise = color_noise,
+					   overlay = overlay,
 					   random_half = settings.random_half,
 					   rgb = (settings.color == "rgb")
 					 })
@@ -94,6 +96,7 @@ local function train()
 					settings.noise_level,
 					settings.crop_size, offset,
 					{ color_noise = color_noise,
+					  overlay = overlay,
 					  random_half = settings.random_half,
 					  rgb = (settings.color == "rgb")
 					})
@@ -104,6 +107,7 @@ local function train()
 					      settings.noise_level,
 					      settings.crop_size, offset,
 					      { color_noise = color_noise,
+						overlay = overlay,
 						random_half = settings.random_half,
 						rgb = (settings.color == "rgb")
 					      })