فهرست منبع

performance tuning

nagadomi 8 سال پیش
والد
کامیت
f7e83e4465

+ 1 - 4
lib/pairwise_transform_jpeg.lua

@@ -30,10 +30,7 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
    assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
    
    local batch = {}
-   local lowres_y = gm.Image(y, "RGB", "DHW"):
-      size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
-      size(y:size(3), y:size(2), "Box"):
-      toTensor(t, "RGB", "DHW")
+   local lowres_y = pairwise_utils.low_resolution(y)
 
    local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
    for i = 1, n do

+ 1 - 4
lib/pairwise_transform_jpeg_scale.lua

@@ -91,10 +91,7 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off
       assert(x:size(1) == y:size(1) and x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
    end
    local batch = {}
-   local lowres_y = gm.Image(y, "RGB", "DHW"):
-      size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
-      size(y:size(3), y:size(2), "Box"):
-      toTensor(t, "RGB", "DHW")
+   local lowres_y = pairwise_utils.low_resolution(y)
    local x_noise = add_jpeg_noise(x, style, noise_level, options)
 
    local xs, ys, ls, ns = pairwise_utils.flip_augmentation(x, y, lowres_y, x_noise)

+ 1 - 4
lib/pairwise_transform_scale.lua

@@ -37,10 +37,7 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       assert(x:size(1) == y:size(1) and x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
    end
    local batch = {}
-   local lowres_y = gm.Image(y, "RGB", "DHW"):
-      size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
-      size(y:size(3), y:size(2), "Box"):
-      toTensor(t, "RGB", "DHW")
+   local lowres_y = pairwise_utils.low_resolution(y)
    local xs, ys, ls, _ = pairwise_utils.flip_augmentation(x, y, lowres_y)
    for i = 1, n do
       local t = (i % #xs) + 1

+ 2 - 4
lib/pairwise_transform_user.lua

@@ -36,10 +36,7 @@ function pairwise_transform.user(x, y, size, offset, n, options)
    x, y = crop_if_large(x, y, scale_y, options.max_size, scale_y)
    assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
    local batch = {}
-   local lowres_y = gm.Image(y, "RGB", "DHW"):
-      size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
-      size(y:size(3), y:size(2), "Box"):
-      toTensor(t, "RGB", "DHW")
+   local lowres_y = pairwise_utils.low_resolution(y)
    local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
    for i = 1, n do
       local t = (i % #xs) + 1
@@ -55,6 +52,7 @@ function pairwise_transform.user(x, y, size, offset, n, options)
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end
+
    return batch
 end
 return pairwise_transform

+ 60 - 14
lib/pairwise_transform_utils.lua

@@ -1,5 +1,7 @@
 require 'image'
+require 'cunn'
 local iproc = require 'iproc'
+local gm = require 'graphicsmagick'
 local data_augmentation = require 'data_augmentation'
 local pairwise_transform_utils = {}
 
@@ -76,24 +78,26 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
       local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
       return xc, yc
    else
-      local best_se = 0.0
-      local best_xi, best_yi
-      local m = torch.LongTensor(y:size(1), size, size)
-      local targets = {}
+      local xcs = torch.LongTensor(tries, y:size(1), size, size)
+      local lcs = torch.LongTensor(tries, lowres_y:size(1), size, size)
+      local rects = {}
+      local r = torch.LongTensor(2, tries)
+      r[1]:random(1, x:size(3) - (size + 1)):mul(scale)
+      r[2]:random(1, x:size(2) - (size + 1)):mul(scale)
       for i = 1, tries do
-	 local xi = torch.random(1, x:size(3) - (size + 1)) * scale
-	 local yi = torch.random(1, x:size(2) - (size + 1)) * scale
+	 local xi = r[1][i]
+	 local yi = r[2][i]
 	 local xc = iproc.crop_nocopy(y, xi, yi, xi + size, yi + size)
 	 local lc = iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size)
-	 m:copy(xc:long()):csub(lc:long())
-	 m:cmul(m)
-	 local se = m:sum()
-	 if se >= best_se then
-	    best_xi = xi
-	    best_yi = yi
-	    best_se = se
-	 end
+	 xcs[i]:copy(xc)
+	 lcs[i]:copy(lc)
+	 rects[i] = {xi, yi}
       end
+      xcs:csub(lcs)
+      xcs:cmul(xcs)
+      local v, l = xcs:reshape(xcs:size(1), xcs:nElement() / xcs:size(1)):transpose(1, 2):sum(1):topk(1, true)
+      local best_xi = rects[l[1][1]][1]
+      local best_yi = rects[l[1][1]][2]
       local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
       local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + size / scale, best_yi / scale + size / scale)
       return xc, yc
@@ -158,5 +162,47 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
    end
    return xs, ys, ls, ns
 end
+local function lowres_model() -- builds the GPU network that low_resolution() uses to blur an image
+   local seq = nn.Sequential()
+   seq:add(nn.SpatialAveragePooling(2, 2, 2, 2)) -- 2x2 averaging window moved with stride 2 in both directions
+   seq:add(nn.SpatialUpSamplingNearest(2)) -- scale back up by 2; NOTE(review): for odd height/width the result is presumably one pixel smaller than the input - confirm callers only pass even sizes
+   return seq:cuda() -- convert the module to CUDA before returning it
+end
+local g_lowres_model = nil -- network from lowres_model(), built on first use and reused by later calls
+local g_lowres_gpu = nil -- nil until the benchmark below has run; afterwards true (GPU path) or false (GraphicsMagick path)
+function pairwise_transform_utils.low_resolution(src) -- returns a byte tensor holding src reduced to half resolution and scaled back up; src is handed to gm.Image as "RGB"/"DHW"
+   g_lowres_model = g_lowres_model or lowres_model()
+   if g_lowres_gpu == nil then
+      --benchmark: on the first call only, time 10 runs of each implementation on this src and remember which one was faster
+      local gpu_time = sys.clock() -- NOTE(review): sys is not required in the visible header of this file; presumably a global loaded by another package - confirm
+      for i = 1, 10 do
+	 g_lowres_model:forward(src:cuda()):byte() -- result discarded; only the elapsed time matters here
+      end
+      gpu_time = sys.clock() - gpu_time
+
+      local cpu_time = sys.clock()
+      for i = 1, 10 do
+	 gm.Image(src, "RGB", "DHW"): -- same GraphicsMagick chain as the CPU branch at the end of this function
+	    size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+	    size(src:size(3), src:size(2), "Box"):
+	    toTensor("byte", "RGB", "DHW")
+      end
+      cpu_time = sys.clock() - cpu_time
+      --print(gpu_time, cpu_time) -- uncomment to inspect the two measured timings
+      if gpu_time < cpu_time then -- a tie selects the CPU path
+	 g_lowres_gpu = true
+      else
+	 g_lowres_gpu = false
+      end
+   end
+   if g_lowres_gpu then -- the choice was made once, from the first src seen, and is not re-evaluated for later inputs
+      return g_lowres_model:forward(src:cuda()):byte() -- NOTE(review): output size may differ from src when height/width are odd, unlike the CPU path which resizes back to src:size() - confirm
+   else
+      return gm.Image(src, "RGB", "DHW"): -- CPU path: "Box" resize to half width/height, then back to the original width/height
+	 size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+	 size(src:size(3), src:size(2), "Box"):
+	    toTensor("byte", "RGB", "DHW")
+   end
+end
 
 return pairwise_transform_utils