Преглед изворни кода

Optimize flip in user method when -patchs is small

nagadomi пре 8 година
родитељ
комит
c6e3a68974
2 измењених фајлова са 36 додато и 2 уклоњено
  1. 30 0
      lib/data_augmentation.lua
  2. 6 2
      lib/pairwise_transform_user.lua

+ 30 - 0
lib/data_augmentation.lua

@@ -139,6 +139,36 @@ function data_augmentation.pairwise_negate_x(x, y, p)
       return x, y
    end
 end
+function data_augmentation.pairwise_flip(x, y)
+   local flip = torch.random(1, 4)
+   local tr = torch.random(1, 2)
+   local x, conversion = iproc.byte2float(x)
+   y = iproc.byte2float(y)
+   x = x:contiguous()
+   y = y:contiguous()
+   if tr == 1 then
+      -- pass
+   elseif tr == 2 then
+      x = x:transpose(2, 3):contiguous()
+      y = y:transpose(2, 3):contiguous()
+   end
+   if flip == 1 then
+      x = iproc.hflip(x)
+      y = iproc.hflip(y)
+   elseif flip == 2 then
+      x = iproc.vflip(x)
+      y = iproc.vflip(y)
+   elseif flip == 3 then
+      x = iproc.hflip(iproc.vflip(x))
+      y = iproc.hflip(iproc.vflip(y))
+   elseif flip == 4 then
+   end
+   if conversion then
+      x = iproc.float2byte(x)
+      y = iproc.float2byte(y)
+   end
+   return x, y
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)

+ 6 - 2
lib/pairwise_transform_user.lua

@@ -1,4 +1,5 @@
 local pairwise_utils = require 'pairwise_transform_utils'
+local data_augmentation = require 'data_augmentation'
 local iproc = require 'iproc'
 local gm = {}
 gm.Image = require 'graphicsmagick.Image'
@@ -21,12 +22,15 @@ function pairwise_transform.user(x, y, size, offset, n, options)
    if options.active_cropping_rate > 0 then
       lowres_y = pairwise_utils.low_resolution(y)
    end
-   if options.pairwise_flip then
+   if options.pairwise_flip and n == 1 then
+      xs[1], ys[1] = data_augmentation.pairwise_flip(xs[1], ys[1])
+   elseif options.pairwise_flip then
       xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
    end
    assert(#xs == #ys)
+   local perm = torch.randperm(#xs)
    for i = 1, n do
-      local t = (i % #xs) + 1
+      local t = perm[(i % #xs) + 1]
       local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
 						    options.active_cropping_rate,
 						    options.active_cropping_tries)