Explorar el Código

Add -max_training_image_size option

nagadomi hace 9 años
padre
commit
8d6451a51b
Se han modificado 2 ficheros con 24 adiciones y 1 borrado
  1. 22 0
      convert_data.lua
  2. 2 1
      lib/settings.lua

+ 22 - 0
convert_data.lua

@@ -8,6 +8,25 @@ local settings = require 'settings'
 local image_loader = require 'image_loader'
 local iproc = require 'iproc'
 
+--- Crop a random max_size x max_size window out of an image when both of
+-- its spatial extents are at least max_size.
+--
+-- Up to 4 random windows are tried; the first one whose pixel values are not
+-- perfectly uniform (standard deviation > 0) is returned. If every attempt
+-- lands on a flat region, the last window tried is returned anyway.
+--
+-- @param src image; size(2) bounds the y offset and size(3) bounds the x
+--   offset of the window.
+-- @param max_size side length of the square window.
+-- @return the cropped window, or src itself (unchanged) when either of
+--   those two extents is smaller than max_size.
+local function crop_if_large(src, max_size)
+   local tries = 4
+   local height, width = src:size(2), src:size(3)
+   if height < max_size or width < max_size then
+      return src
+   end
+   local rect
+   for _ = 1, tries do
+      local yi = torch.random(0, height - max_size)
+      local xi = torch.random(0, width - max_size)
+      rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
+      -- Ignore simple background: std() is never negative, so the former
+      -- `>= 0` test accepted every window and the retries never ran.
+      if rect:float():std() > 0 then
+         break
+      end
+   end
+   return rect
+end
+
 local function load_images(list)
    local MARGIN = 32
    local lines = utils.split(file.read(list), "\n")
@@ -18,6 +37,9 @@ local function load_images(list)
       if alpha then
 	 io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
       else
+	 if settings.max_training_image_size > 0 then
+	    im = crop_if_large(im, settings.max_training_image_size)
+	 end
 	 im = iproc.crop_mod4(im)
 	 local scale = 1.0
 	 if settings.random_half_rate > 0.0 then

+ 2 - 1
lib/settings.lua

@@ -34,7 +34,7 @@ cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp ma
 cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.0005, 'learning rate for adam')
 cmd:option("-crop_size", 46, 'crop size')
-cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
+cmd:option("-max_size", 256, 'if image is larger than N, image will be crop randomly')
 cmd:option("-batch_size", 8, 'mini batch size')
 cmd:option("-patches", 16, 'number of patch samples')
 cmd:option("-inner_epoch", 4, 'number of inner epochs')
@@ -51,6 +51,7 @@ cmd:option("-plot", 0, 'plot loss chart(0|1)')
 cmd:option("-downsampling_filters", "Box,Catrom", '(comma separated)downsampling filters for 2x scale training. (Point,Box,Triangle,Hermite,Hanning,Hamming,Blackman,Gaussian,Quadratic,Cubic,Catrom,Mitchell,Lanczos,Bessel,Sinc)')
 cmd:option("-gamma_correction", 0, 'Resizing with colorspace correction(sRGB:gamma 2.2) in scale training (0|1)')
 cmd:option("-upsampling_filter", "Box", 'upsampling filter for 2x scale training (dev)')
+cmd:option("-max_training_image_size", -1, 'if training image is larger than N, image will be crop randomly when data converting')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then