Ver código fonte

tunable parameters

nagadomi 9 anos atrás
pai
commit
3ea16b3b86
4 arquivos alterados com 69 adições e 86 exclusões
  1. + 35 - 32
      lib/data_augmentation.lua
  2. + 5 - 12
      lib/pairwise_transform.lua
  3. + 5 - 20
      lib/settings.lua
  4. + 24 - 22
      train.lua

+ 35 - 32
lib/data_augmentation.lua

@@ -11,25 +11,44 @@ local function pcacov(x)
    local ce, cv = torch.symeig(c, 'V')
    return ce, cv
 end
-function data_augmentation.color_noise(src, factor)
+function data_augmentation.color_noise(src, p, factor)
    factor = factor or 0.1
-   local src, conversion = iproc.byte2float(src)
-   local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous()
-   local ce, cv = pcacov(src_t)
-   local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor)
-   
-   pca_space = torch.mm(src_t, cv):t():contiguous()
-   for i = 1, 3 do
-      pca_space[i]:mul(color_scale[i])
-   end
-   local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src)
-   dest[torch.lt(dest, 0.0)] = 0.0
-   dest[torch.gt(dest, 1.0)] = 1.0
+   if torch.uniform() < p then
+      local src, conversion = iproc.byte2float(src)
+      local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous()
+      local ce, cv = pcacov(src_t)
+      local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor)
+      
+      pca_space = torch.mm(src_t, cv):t():contiguous()
+      for i = 1, 3 do
+	 pca_space[i]:mul(color_scale[i])
+      end
+      local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src)
+      dest[torch.lt(dest, 0.0)] = 0.0
+      dest[torch.gt(dest, 1.0)] = 1.0
 
-   if conversion then
-      dest = iproc.float2byte(dest)
+      if conversion then
+	 dest = iproc.float2byte(dest)
+      end
+      return dest
+   else
+      return src
+   end
+end
+function data_augmentation.overlay(src, p)
+   if torch.uniform() < p then
+      local r = torch.uniform()
+      local src, conversion = iproc.byte2float(src)
+      src = src:contiguous()
+      local flip = data_augmentation.flip(src)
+      flip:mul(r):add(src * (1.0 - r))
+      if conversion then
+	 flip = iproc.float2byte(flip)
+      end
+      return flip
+   else
+      return src
    end
-   return dest
 end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
@@ -76,20 +95,4 @@ function data_augmentation.flip(src)
    end
    return dest
 end
-function data_augmentation.overlay(src, p)
-   p = p or 0.25
-   if torch.uniform() < p then
-      local r = torch.uniform(0.2, 0.8)
-      local src, conversion = iproc.byte2float(src)
-      src = src:contiguous()
-      local flip = data_augmentation.flip(src)
-      flip:mul(r):add(src * (1.0 - r))
-      if conversion then
-	 flip = iproc.float2byte(flip)
-      end
-      return flip
-   else
-      return src
-   end
-end
 return data_augmentation

+ 5 - 12
lib/pairwise_transform.lua

@@ -6,9 +6,8 @@ local data_augmentation = require 'data_augmentation'
 local pairwise_transform = {}
 
 local function random_half(src, p)
-   p = p or 0.25
-   local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
-   if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
+   if torch.uniform() < p then
+      local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
       return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
    else
       return src
@@ -34,17 +33,11 @@ local function crop_if_large(src, max_size)
 end
 local function preprocess(src, crop_size, options)
    local dest = src
-   if options.random_half then
-      dest = random_half(dest)
-   end
+   dest = random_half(dest, options.random_half_rate)
    dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
    dest = data_augmentation.flip(dest)
-   if options.color_noise then
-      dest = data_augmentation.color_noise(dest)
-   end
-   if options.overlay then
-      dest = data_augmentation.overlay(dest)
-   end
+   dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
+   dest = data_augmentation.overlay(dest, options.random_overlay_rate)
    dest = data_augmentation.shift_1px(dest)
    
    return dest

+ 5 - 20
lib/settings.lua

@@ -26,19 +26,19 @@ cmd:option("-method", "scale", 'method to training (noise|scale)')
 cmd:option("-noise_level", 1, '(1|2)')
 cmd:option("-style", "art", '(art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
-cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)')
-cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)')
+cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
+cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
+cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
 cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
-cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)')
 cmd:option("-crop_size", 46, 'crop size')
 cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
 cmd:option("-batch_size", 8, 'mini batch size')
 cmd:option("-epoch", 200, 'number of total epochs to run')
 cmd:option("-thread", -1, 'number of CPU threads')
 cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
-cmd:option("-validation_rate", 0.05, 'validation-set rate of data')
-cmd:option("-validation_crops", 80, 'number of region per image in validation')
+cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
+cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
 cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
 cmd:option("-active_cropping_tries", 10, 'active cropping tries')
 cmd:option("-nr_rate", 0.7, 'trade-off between reducing noise and erasing details (0.0-1.0)')
@@ -69,21 +69,6 @@ if not (settings.style == "art" or
 	settings.style == "photo") then
    error(string.format("unknown style: %s", settings.style))
 end
-if settings.random_half == 1 then
-   settings.random_half = true
-else
-   settings.random_half = false
-end
-if settings.color_noise == 1 then
-   settings.color_noise = true
-else
-   settings.color_noise = false
-end
-if settings.overlay == 1 then
-   settings.overlay = true
-else
-   settings.overlay = false
-end
 
 if settings.thread > 0 then
    torch.setnumthreads(tonumber(settings.thread))

+ 24 - 22
train.lua

@@ -85,20 +85,20 @@ local function transformer(x, is_validation, n, offset)
    x = compression.decompress(x)
    n = n or settings.batch_size;
    if is_validation == nil then is_validation = false end
-   local color_noise = nil 
-   local overlay = nil
+   local random_color_noise_rate = nil 
+   local random_overlay_rate = nil
    local active_cropping_rate = nil
    local active_cropping_tries = nil
    if is_validation then
       active_cropping_rate = 0
       active_cropping_tries = 0
-      color_noise = false
-      overlay = false
+      random_color_noise_rate = 0.0
+      random_overlay_rate = 0.0
    else
       active_cropping_rate = settings.active_cropping_rate
       active_cropping_tries = settings.active_cropping_tries
-      color_noise = settings.color_noise
-      overlay = settings.overlay
+      random_color_noise_rate = settings.random_color_noise_rate
+      random_overlay_rate = settings.random_overlay_rate
    end
    
    if settings.method == "scale" then
@@ -106,13 +106,14 @@ local function transformer(x, is_validation, n, offset)
 				      settings.scale,
 				      settings.crop_size, offset,
 				      n,
-				      { color_noise = color_noise,
-					overlay = overlay,
-					random_half = settings.random_half,
-					max_size = settings.max_size,
-					active_cropping_rate = active_cropping_rate,
-					active_cropping_tries = active_cropping_tries,
-					rgb = (settings.color == "rgb")
+				      {
+					 random_half_rate = settings.random_half_rate,
+					 random_color_noise_rate = random_color_noise_rate,
+					 random_overlay_rate = random_overlay_rate,
+					 max_size = settings.max_size,
+					 active_cropping_rate = active_cropping_rate,
+					 active_cropping_tries = active_cropping_tries,
+					 rgb = (settings.color == "rgb")
 				      })
    elseif settings.method == "noise" then
       return pairwise_transform.jpeg(x,
@@ -120,15 +121,16 @@ local function transformer(x, is_validation, n, offset)
 				     settings.noise_level,
 				     settings.crop_size, offset,
 				     n,
-				     { color_noise = color_noise,
-				       overlay = overlay,
-				       random_half = settings.random_half,
-				       max_size = settings.max_size,
-				       jpeg_sampling_factors = settings.jpeg_sampling_factors,
-				       active_cropping_rate = active_cropping_rate,
-				       active_cropping_tries = active_cropping_tries,
-				       nr_rate = settings.nr_rate,
-				       rgb = (settings.color == "rgb")
+				     {
+					random_half_rate = settings.random_half_rate,
+					random_color_noise_rate = random_color_noise_rate,
+					random_overlay_rate = random_overlay_rate,
+					max_size = settings.max_size,
+					jpeg_sampling_factors = settings.jpeg_sampling_factors,
+					active_cropping_rate = active_cropping_rate,
+					active_cropping_tries = active_cropping_tries,
+					nr_rate = settings.nr_rate,
+					rgb = (settings.color == "rgb")
 				     })
    end
 end