Forráskód Böngészése

perfomance tuning

nagadomi 8 éve
szülő
commit
bdaca16c67
6 módosított fájl, 68 hozzáadás és 19 törlés
  1. 11 1
      lib/iproc.lua
  2. 14 4
      lib/pairwise_transform_user.lua
  3. 20 7
      lib/pairwise_transform_utils.lua
  4. 2 0
      lib/settings.lua
  5. 0 3
      lib/srcnn.lua
  6. 21 4
      train.lua

+ 11 - 1
lib/iproc.lua

@@ -178,7 +178,7 @@ local function rotate_with_warp(src, dst, theta, mode)
   flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
   flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
   dst:resizeAs(src)
-  return image.warp(dst, src, flow, mode, true, 'pad')
+  return image.warp(dst, src, flow, mode, true, 'clamp')
 end
 function iproc.rotate(src, theta)
    local conversion
@@ -212,6 +212,16 @@ function iproc.gaussian2d(kernel_size, sigma)
    kernel:div(kernel:sum())
    return kernel
 end
+function iproc.rgb2y(src)
+   local conversion
+   src, conversion = iproc.byte2float(src)
+   local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
+   dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
+end
 
 local function test_conversion()
    local a = torch.linspace(0, 255, 256):float():div(255.0)

+ 14 - 4
lib/pairwise_transform_user.lua

@@ -13,8 +13,18 @@ function pairwise_transform.user(x, y, size, offset, n, options)
    x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
    assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
    local batch = {}
-   local lowres_y = pairwise_utils.low_resolution(y)
-   local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
+   local lowres_y = nil
+   local xs ={x}
+   local ys = {y}
+   local ls = {}
+
+   if options.active_cropping_rate > 0 then
+      lowres_y = pairwise_utils.low_resolution(y)
+   end
+   if options.pairwise_flip then
+      xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
+   end
+   assert(#xs == #ys)
    for i = 1, n do
       local t = (i % #xs) + 1
       local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
@@ -24,8 +34,8 @@ function pairwise_transform.user(x, y, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+	 yc = iproc.rgb2y(yc)
+	 xc = iproc.rgb2y(xc)
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end

+ 20 - 7
lib/pairwise_transform_utils.lua

@@ -164,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
 
    for j = 1, 2 do
       -- TTA
-      local xi, yi, ri
+      local xi, yi, ri, ni
       if j == 1 then
 	 xi = x
 	 ni = x_noise
@@ -176,7 +176,9 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
 	    ni = x_noise:transpose(2, 3):contiguous()
 	 end
 	 yi = y:transpose(2, 3):contiguous()
-	 ri = lowres_y:transpose(2, 3):contiguous()
+	 if lowres_y then
+	    ri = lowres_y:transpose(2, 3):contiguous()
+	 end
       end
       local xv = iproc.vflip(xi)
       local nv
@@ -184,34 +186,45 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
 	 nv = iproc.vflip(ni)
       end
       local yv = iproc.vflip(yi)
-      local rv = iproc.vflip(ri)
+      local rv
+      if ri then
+	 rv = iproc.vflip(ri)
+      end
       table.insert(xs, xi)
       if ni then
 	 table.insert(ns, ni)
       end
       table.insert(ys, yi)
-      table.insert(ls, ri)
+      if ri then
+	 table.insert(ls, ri)
+      end
 
       table.insert(xs, xv)
       if nv then
 	 table.insert(ns, nv)
       end
       table.insert(ys, yv)
-      table.insert(ls, rv)
+      if rv then
+	 table.insert(ls, rv)
+      end
 
       table.insert(xs, iproc.hflip(xi))
       if ni then
 	 table.insert(ns, iproc.hflip(ni))
       end
       table.insert(ys, iproc.hflip(yi))
-      table.insert(ls, iproc.hflip(ri))
+      if ri then
+	 table.insert(ls, iproc.hflip(ri))
+      end
 
       table.insert(xs, iproc.hflip(xv))
       if nv then
 	 table.insert(ns, iproc.hflip(nv))
       end
       table.insert(ys, iproc.hflip(yv))
-      table.insert(ls, iproc.hflip(rv))
+      if rv then
+	 table.insert(ls, iproc.hflip(rv))
+      end
    end
    return xs, ys, ls, ns
 end

+ 2 - 0
lib/settings.lua

@@ -46,6 +46,7 @@ cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwi
 cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
 cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
 cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
+cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
 cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-crop_size", 48, 'crop size')
@@ -91,6 +92,7 @@ to_bool(settings, "plot")
 to_bool(settings, "save_history")
 to_bool(settings, "use_transparent_png")
 to_bool(settings, "pairwise_y_binary")
+to_bool(settings, "pairwise_flip")
 
 if settings.plot then
    require 'gnuplot'

+ 0 - 3
lib/srcnn.lua

@@ -466,9 +466,6 @@ function srcnn.fcn_v1(backend, ch)
    model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))
    model:add(nn.Dropout(0.5, false, true))
-   model:add(SpatialConvolution(backend, 256, 256, 1, 1, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(nn.Dropout(0.5, false, true))
 
    model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
    model:add(nn.LeakyReLU(0.1, true))

+ 21 - 4
train.lua

@@ -175,19 +175,36 @@ local function transform_pool_init(has_resize, offset)
 						    settings.crop_size, offset,
 						    n, conf)
 	    elseif settings.method == "user" then
+	       if is_validation == nil then is_validation = false end
+	       local rotate_rate = nil 
+	       local scale_rate = nil
+	       local negate_rate = nil
+	       local negate_x_rate = nil
+	       if is_validation then
+		  rotate_rate = 0
+		  scale_rate = 0
+		  negate_rate = 0
+		  negate_x_rate = 0
+	       else
+		  rotate_rate = settings.random_pairwise_rotate_rate
+		  scale_rate = settings.random_pairwise_scale_rate
+		  negate_rate = settings.random_pairwise_negate_rate
+		  negate_x_rate = settings.random_pairwise_negate_x_rate
+	       end
 	       local conf = tablex.update({
 		     max_size = settings.max_size,
 		     active_cropping_rate = active_cropping_rate,
 		     active_cropping_tries = active_cropping_tries,
-		     random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
+		     random_pairwise_rotate_rate = rotate_rate,
 		     random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
 		     random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
-		     random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
+		     random_pairwise_scale_rate = scale_rate,
 		     random_pairwise_scale_min = settings.random_pairwise_scale_min,
 		     random_pairwise_scale_max = settings.random_pairwise_scale_max,
-		     random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
-		     random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
+		     random_pairwise_negate_rate = negate_rate,
+		     random_pairwise_negate_x_rate = negate_x_rate,
 		     pairwise_y_binary = settings.pairwise_y_binary,
+		     pairwise_flip = settings.pairwise_flip,
 		     rgb = (settings.color == "rgb")}, meta)
 	       return pairwise_transform.user(x, y,
 					      settings.crop_size, offset,