nagadomi 9 years ago
parent
commit
b35a9ae7d7
5 changed files with 58 additions and 22 deletions
  1. 11 0
      lib/iproc.lua
  2. 1 1
      lib/minibatch_adam.lua
  3. 40 17
      lib/pairwise_transform.lua
  4. 1 1
      lib/settings.lua
  5. 5 3
      train.lua

+ 11 - 0
lib/iproc.lua

@@ -1,5 +1,6 @@
 local gm = require 'graphicsmagick'
 local gm = require 'graphicsmagick'
 local image = require 'image'
 local image = require 'image'
+
 local iproc = {}
 local iproc = {}
 
 
 function iproc.crop_mod4(src)
 function iproc.crop_mod4(src)
@@ -16,6 +17,15 @@ function iproc.crop(src, w1, h1, w2, h2)
    end
    end
    return dest
    return dest
 end
 end
+function iproc.crop_nocopy(src, w1, h1, w2, h2) -- returns the sub-region of src selected by index ranges; same parameter list as iproc.crop
+   local dest
+   if src:dim() == 3 then -- 3-D input: keep every index of dim 1, slice dims 2 and 3
+      dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}] -- dim 2 takes h1+1..h2, dim 3 takes w1+1..w2 (w1/h1 act as 0-based offsets)
+   else -- any other rank; assumed dim == 2, both dimensions sliced directly
+      dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}]
+   end
+   return dest -- NOTE(review): presumably a sub-tensor sharing storage with src (hence "nocopy") - confirm against Torch indexing semantics
+end
 function iproc.byte2float(src)
 function iproc.byte2float(src)
    local conversion = false
    local conversion = false
    local dest = src
    local dest = src
@@ -55,4 +65,5 @@ function iproc.padding(img, w1, w2, h1, h2)
    return image.warp(img, flow, "simple", false, "clamp")
    return image.warp(img, flow, "simple", false, "clamp")
 end
 end
 
 
+
 return iproc
 return iproc

+ 1 - 1
lib/minibatch_adam.lua

@@ -45,7 +45,7 @@ local function minibatch_adam(model, criterion,
       optim.adam(feval, parameters, config)
       optim.adam(feval, parameters, config)
       
       
       c = c + 1
       c = c + 1
-      if c % 10 == 0 then
+      if c % 20 == 0 then
 	 collectgarbage()
 	 collectgarbage()
       end
       end
    end
    end

+ 40 - 17
lib/pairwise_transform.lua

@@ -16,10 +16,19 @@ local function random_half(src, p)
    end
    end
 end
 end
 local function crop_if_large(src, max_size)
 local function crop_if_large(src, max_size)
+   local tries = 4
    if src:size(2) > max_size and src:size(3) > max_size then
    if src:size(2) > max_size and src:size(3) > max_size then
-      local yi = torch.random(0, src:size(2) - max_size)
-      local xi = torch.random(0, src:size(3) - max_size)
-      return iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
+      local rect
+      for i = 1, tries do
+	 local yi = torch.random(0, src:size(2) - max_size)
+	 local xi = torch.random(0, src:size(3) - max_size)
+	 rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
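+	 -- ignore simple background (NOTE(review): std() is never negative, so the ">= 0" test below always passes and the loop exits on the first try; a positive threshold is probably intended)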
+	 if rect:float():std() >= 0 then
+	    break
+	 end
+      end
+      return rect
    else
    else
       return src
       return src
    end
    end
@@ -29,7 +38,7 @@ local function preprocess(src, crop_size, options)
    if options.random_half then
    if options.random_half then
       dest = random_half(dest)
       dest = random_half(dest)
    end
    end
-   dest = crop_if_large(dest, math.max(crop_size * 4, 512))
+   dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
    dest = data_augmentation.flip(dest)
    dest = data_augmentation.flip(dest)
    if options.color_noise then
    if options.color_noise then
       dest = data_augmentation.color_noise(dest)
       dest = data_augmentation.color_noise(dest)
@@ -52,7 +61,9 @@ local function active_cropping(x, y, size, p, tries)
       return xc, yc
       return xc, yc
    else
    else
       local samples = {}
       local samples = {}
-      local sum_mse = 0
+      local best_se = 0.0
+      local best_xc, best_yc
+      local m = torch.FloatTensor(x:size(1), size, size)
       for i = 1, tries do
       for i = 1, tries do
 	 local xi = torch.random(0, y:size(3) - (size + 1))
 	 local xi = torch.random(0, y:size(3) - (size + 1))
 	 local yi = torch.random(0, y:size(2) - (size + 1))
 	 local yi = torch.random(0, y:size(2) - (size + 1))
@@ -60,17 +71,14 @@ local function active_cropping(x, y, size, p, tries)
 	 local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
 	 local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
 	 local xcf = iproc.byte2float(xc)
 	 local xcf = iproc.byte2float(xc)
 	 local ycf = iproc.byte2float(yc)
 	 local ycf = iproc.byte2float(yc)
-	 local mse = (xcf - ycf):pow(2):mean()
-	 sum_mse = sum_mse + mse
-	 table.insert(samples, {xc = xc, yc = yc, mse = mse})
-      end
-      if sum_mse > 0 then
-	 table.sort(samples,
-		    function (a, b)
-		       return a.mse > b.mse
-		    end)
+	 local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
+	 if se >= best_se then
+	    best_xc = xcf
+	    best_yc = ycf
+	    best_se = se
+	 end
       end
       end
-      return samples[1].xc, samples[1].yc
+      return best_xc, best_yc
    end
    end
 end
 end
 function pairwise_transform.scale(src, scale, size, offset, n, options)
 function pairwise_transform.scale(src, scale, size, offset, n, options)
@@ -83,6 +91,7 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       "SincFast",   -- 0.014095824314306
       "SincFast",   -- 0.014095824314306
       "Jinc",       -- 0.014244299255442
       "Jinc",       -- 0.014244299255442
    }
    }
+   local unstable_region_offset = 8
    local downscale_filter = filters[torch.random(1, #filters)]
    local downscale_filter = filters[torch.random(1, #filters)]
    local y = preprocess(src, size, options)
    local y = preprocess(src, size, options)
    assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
    assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
@@ -90,6 +99,13 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
    local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
    local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
 				     y:size(2) * down_scale, downscale_filter),
 				     y:size(2) * down_scale, downscale_filter),
 			 y:size(3), y:size(2))
 			 y:size(3), y:size(2))
+   x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
+		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
+   y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
+		  y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
+   assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
+   assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
+   
    local batch = {}
    local batch = {}
    for i = 1, n do
    for i = 1, n do
       local xc, yc = active_cropping(x, y,
       local xc, yc = active_cropping(x, y,
@@ -108,8 +124,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
    return batch
    return batch
 end
 end
 function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
 function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
+   local unstable_region_offset = 8
    local y = preprocess(src, size, options)
    local y = preprocess(src, size, options)
    local x = y
    local x = y
+
    for i = 1, #quality do
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
       x = gm.Image(x, "RGB", "DHW")
       x:format("jpeg")
       x:format("jpeg")
@@ -122,7 +140,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
       x:fromBlob(blob, len)
       x:fromBlob(blob, len)
       x = x:toTensor("byte", "RGB", "DHW")
       x = x:toTensor("byte", "RGB", "DHW")
    end
    end
-   -- TODO: use shift_1px after compression?
+   x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
+		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
+   y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
+		  y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
+   assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
+   assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
    
    
    local batch = {}
    local batch = {}
    for i = 1, n do
    for i = 1, n do
@@ -152,7 +175,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
 	 end
 	 end
       elseif level == 2 then
       elseif level == 2 then
 	 local r = torch.uniform()
 	 local r = torch.uniform()
-	 if torch.uniform() > 0.8 then
+	 if torch.uniform() > 0.9 then
 	    return pairwise_transform.jpeg_(src, {},
 	    return pairwise_transform.jpeg_(src, {},
 					    size, offset, n, options)
 					    size, offset, n, options)
 	 else
 	 else

+ 1 - 1
lib/settings.lua

@@ -32,7 +32,7 @@ cmd:option("-scale", 2.0, 'scale')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
 cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
 cmd:option("-crop_size", 128, 'crop size')
 cmd:option("-crop_size", 128, 'crop size')
-cmd:option("-max_size", -1, 'crop if image size larger then this value.')
+cmd:option("-max_size", 512, 'crop if image size larger then this value.')
 cmd:option("-batch_size", 2, 'mini batch size')
 cmd:option("-batch_size", 2, 'mini batch size')
 cmd:option("-epoch", 200, 'epoch')
 cmd:option("-epoch", 200, 'epoch')
 cmd:option("-thread", -1, 'number of CPU threads')
 cmd:option("-thread", -1, 'number of CPU threads')

+ 5 - 3
train.lua

@@ -91,7 +91,7 @@ local function transformer(x, is_validation, n, offset)
    local active_cropping_tries = nil
    local active_cropping_tries = nil
    
    
    if is_validation then
    if is_validation then
-      active_cropping_rate = 0.0
+      active_cropping_rate = 0
       active_cropping_tries = 0
       active_cropping_tries = 0
       color_noise = false
       color_noise = false
       overlay = false
       overlay = false
@@ -110,6 +110,7 @@ local function transformer(x, is_validation, n, offset)
 				      { color_noise = color_noise,
 				      { color_noise = color_noise,
 					overlay = overlay,
 					overlay = overlay,
 					random_half = settings.random_half,
 					random_half = settings.random_half,
+					max_size = settings.max_size,
 					active_cropping_rate = active_cropping_rate,
 					active_cropping_rate = active_cropping_rate,
 					active_cropping_tries = active_cropping_tries,
 					active_cropping_tries = active_cropping_tries,
 					rgb = (settings.color == "rgb")
 					rgb = (settings.color == "rgb")
@@ -122,10 +123,11 @@ local function transformer(x, is_validation, n, offset)
 				     n,
 				     n,
 				     { color_noise = color_noise,
 				     { color_noise = color_noise,
 				       overlay = overlay,
 				       overlay = overlay,
-				       active_cropping_rate = active_cropping_rate,
-				       active_cropping_tries = active_cropping_tries,
 				       random_half = settings.random_half,
 				       random_half = settings.random_half,
+				       max_size = settings.max_size,
 				       jpeg_sampling_factors = settings.jpeg_sampling_factors,
 				       jpeg_sampling_factors = settings.jpeg_sampling_factors,
+				       active_cropping_rate = active_cropping_rate,
+				       active_cropping_tries = active_cropping_tries,
 				       rgb = (settings.color == "rgb")
 				       rgb = (settings.color == "rgb")
 				     })
 				     })
    end
    end