Parcourir la source

Add -random_unsharp_mask_rate option for photo

nagadomi il y a 9 ans
Parent
commit
c72ec3112b
4 fichiers modifiés avec 29 ajouts et 1 suppressions
  1. 20 0
      lib/data_augmentation.lua
  2. 6 1
      lib/pairwise_transform.lua
  3. 1 0
      lib/settings.lua
  4. 2 0
      train.lua

+ 20 - 0
lib/data_augmentation.lua

@@ -1,5 +1,6 @@
 require 'image'
 local iproc = require 'iproc'
+local gm = require 'graphicsmagick'
 
 local data_augmentation = {}
 
@@ -50,6 +51,25 @@ function data_augmentation.overlay(src, p)
       return src
    end
 end
+function data_augmentation.unsharp_mask(src, p)
+   if torch.uniform() < p then
+      local radius = 0 -- auto
+      local sigma = torch.uniform(0.7, 1.4)
+      local amount = torch.uniform(0.5, 1.0)
+      local threshold = torch.uniform(0.0, 0.05)
+      local unsharp = gm.Image(src, "RGB", "DHW"):
+	 unsharpMask(radius, sigma, amount, threshold):
+	 toTensor("float", "RGB", "DHW")
+      
+      if src:type() == "torch.ByteTensor" then
+	 return iproc.float2byte(unsharp)
+      else
+	 return unsharp
+      end
+   else
+      return src
+   end
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)

+ 6 - 1
lib/pairwise_transform.lua

@@ -7,7 +7,7 @@ local pairwise_transform = {}
 
 local function random_half(src, p)
    if torch.uniform() < p then
-      local filter = ({"Box","Box","Blackman","Sinc","Lanczos"})[torch.random(1, 5)]
+      local filter = ({"Box","Box","Blackman","Sinc","Lanczos", "Catrom"})[torch.random(1, 6)]
       return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
    else
       return src
@@ -38,6 +38,7 @@ local function preprocess(src, crop_size, options)
    dest = data_augmentation.flip(dest)
    dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
    dest = data_augmentation.overlay(dest, options.random_overlay_rate)
+   dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
    dest = data_augmentation.shift_1px(dest)
    
    return dest
@@ -81,6 +82,7 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       --"Hermite",    -- 0.013850225205266
       "Sinc",   -- 0.014095824314306
       "Lanczos",       -- 0.014244299255442
+      "Catrom"
    }
    local unstable_region_offset = 8
    local downscale_filter = filters[torch.random(1, #filters)]
@@ -211,6 +213,8 @@ function pairwise_transform.test_jpeg(src)
    local options = {random_color_noise_rate = 0.5,
 		    random_half_rate = 0.5,
 		    random_overlay_rate = 0.5,
+		    random_unsharp_mask_rate = 0.5,
+		    jpeg_chroma_subsampling_rate = 0.5,
 		    nr_rate = 1.0,
 		    active_cropping_rate = 0.5,
 		    active_cropping_tries = 10,
@@ -233,6 +237,7 @@ function pairwise_transform.test_scale(src)
    local options = {random_color_noise_rate = 0.5,
 		    random_half_rate = 0.5,
 		    random_overlay_rate = 0.5,
+		    random_unsharp_mask_rate = 0.5,
 		    active_cropping_rate = 0.5,
 		    active_cropping_tries = 10,
 		    max_size = 256,

+ 1 - 0
lib/settings.lua

@@ -30,6 +30,7 @@ cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
 cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
 cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
+cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
 cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-crop_size", 46, 'crop size')

+ 2 - 0
train.lua

@@ -110,6 +110,7 @@ local function transformer(x, is_validation, n, offset)
 					 random_half_rate = settings.random_half_rate,
 					 random_color_noise_rate = random_color_noise_rate,
 					 random_overlay_rate = random_overlay_rate,
+					 random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
 					 max_size = settings.max_size,
 					 active_cropping_rate = active_cropping_rate,
 					 active_cropping_tries = active_cropping_tries,
@@ -125,6 +126,7 @@ local function transformer(x, is_validation, n, offset)
 					random_half_rate = settings.random_half_rate,
 					random_color_noise_rate = random_color_noise_rate,
 					random_overlay_rate = random_overlay_rate,
+					random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
 					max_size = settings.max_size,
 					jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
 					active_cropping_rate = active_cropping_rate,