nagadomi 9 年之前
父节点
当前提交
9d3e1a241e
共有 4 个文件被更改,包括 87 次插入93 次删除
  1. 13 10
      lib/pairwise_transform_jpeg.lua
  2. 13 50
      lib/pairwise_transform_jpeg_scale.lua
  3. 2 33
      lib/pairwise_transform_scale.lua
  4. 59 0
      lib/pairwise_transform_utils.lua

+ 13 - 10
lib/pairwise_transform_jpeg.lua

@@ -7,18 +7,18 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
    local unstable_region_offset = 8
    local y = pairwise_utils.preprocess(src, size, options)
    local x = y
+   local factors
 
+   if torch.uniform() < options.jpeg_chroma_subsampling_rate then
+      -- YUV 420
+      factors = {2.0, 1.0, 1.0}
+   else
+      -- YUV 444
+      factors = {1.0, 1.0, 1.0}
+   end
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
-      x:format("jpeg"):depth(8)
-      if torch.uniform() < options.jpeg_chroma_subsampling_rate then
-	 -- YUV 420
-	 x:samplingFactors({2.0, 1.0, 1.0})
-      else
-	 -- YUV 444
-	 x:samplingFactors({1.0, 1.0, 1.0})
-      end
-      local blob, len = x:toBlob(quality[i])
+      local blob, len = x:format("jpeg"):depth(8):samplingFactors(factors):toBlob(quality[i])
       x:fromBlob(blob, len)
       x = x:toTensor("byte", "RGB", "DHW")
    end
@@ -34,8 +34,11 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
       size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
       size(y:size(3), y:size(2), "Box"):
       toTensor(t, "RGB", "DHW")
+
+   local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
    for i = 1, n do
-      local xc, yc = pairwise_utils.active_cropping(x, y, lowres_y, size, 1,
+      local t = (i % #xs) + 1
+      local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, 1,
 						    options.active_cropping_rate,
 						    options.active_cropping_tries)
       xc = iproc.byte2float(xc)

+ 13 - 50
lib/pairwise_transform_jpeg_scale.lua

@@ -4,17 +4,17 @@ local gm = require 'graphicsmagick'
 local pairwise_transform = {}
 
 local function add_jpeg_noise_(x, quality, options)
+   local factors
+   if torch.uniform() < options.jpeg_chroma_subsampling_rate then
+      -- YUV 420
+      factors = {2.0, 1.0, 1.0}
+   else
+      -- YUV 444
+      factors = {1.0, 1.0, 1.0}
+   end
    for i = 1, #quality do
       x = gm.Image(x, "RGB", "DHW")
-      x:format("jpeg"):depth(8)
-      if torch.uniform() < options.jpeg_chroma_subsampling_rate then
-	 -- YUV 420
-	 x:samplingFactors({2.0, 1.0, 1.0})
-      else
-	 -- YUV 444
-	 x:samplingFactors({1.0, 1.0, 1.0})
-      end
-      local blob, len = x:toBlob(quality[i])
+      local blob, len = x:format("jpeg"):depth(8):samplingFactors(factors):toBlob(quality[i])
       x:fromBlob(blob, len)
       x = x:toTensor("byte", "RGB", "DHW")
    end
@@ -89,59 +89,22 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off
       size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
       size(y:size(3), y:size(2), "Box"):
       toTensor(t, "RGB", "DHW")
-   local xs = {}
-   local ns = {}
-   local ys = {}
    local x_noise = add_jpeg_noise(x, style, noise_level, options)
-   local lowreses = {}
-   for j = 1, 2 do
-      -- TTA
-      local xi, yi, ri
-      if j == 1 then
-	 xi = x
-	 ni = x_noise
-	 yi = y
-	 ri = lowres_y
-      else
-	 xi = x:transpose(2, 3):contiguous()
-	 ni = x_noise:transpose(2, 3):contiguous()
-	 yi = y:transpose(2, 3):contiguous()
-	 ri = lowres_y:transpose(2, 3):contiguous()
-      end
-      local xv = image.vflip(xi)
-      local nv = image.vflip(ni)
-      local yv = image.vflip(yi)
-      local rv = image.vflip(ri)
-      table.insert(xs, xi)
-      table.insert(ns, ni)
-      table.insert(ys, yi)
-      table.insert(lowreses, ri)
-      table.insert(xs, xv)
-      table.insert(ns, nv)
-      table.insert(ys, yv)
-      table.insert(lowreses, rv)
-      table.insert(xs, image.hflip(xi))
-      table.insert(ns, image.hflip(ni))
-      table.insert(ys, image.hflip(yi))
-      table.insert(lowreses, image.hflip(ri))
-      table.insert(xs, image.hflip(xv))
-      table.insert(ns, image.hflip(nv))
-      table.insert(ys, image.hflip(yv))
-      table.insert(lowreses, image.hflip(rv))
-   end
+
+   local xs, ys, ls, ns = pairwise_utils.flip_augmentation(x, y, lowres_y, x_noise)
    for i = 1, n do
       local t = (i % #xs) + 1
       local xc, yc
       if torch.uniform() < options.nr_rate then
 	 -- scale + noise reduction
-	 xc, yc = pairwise_utils.active_cropping(ns[t], ys[t], lowreses[t],
+	 xc, yc = pairwise_utils.active_cropping(ns[t], ys[t], ls[t],
 						 size,
 						 scale_inner,
 						 options.active_cropping_rate,
 						 options.active_cropping_tries)
       else
 	 -- scale
-	 xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], lowreses[t],
+	 xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t],
 						 size,
 						 scale_inner,
 						 options.active_cropping_rate,

+ 2 - 33
lib/pairwise_transform_scale.lua

@@ -41,41 +41,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
       size(y:size(3), y:size(2), "Box"):
       toTensor(t, "RGB", "DHW")
-   local xs = {}
-   local ys = {}
-   local lowreses = {}
-
-   for j = 1, 2 do
-      -- TTA
-      local xi, yi, ri
-      if j == 1 then
-	 xi = x
-	 yi = y
-	 ri = lowres_y
-      else
-	 xi = x:transpose(2, 3):contiguous()
-	 yi = y:transpose(2, 3):contiguous()
-	 ri = lowres_y:transpose(2, 3):contiguous()
-      end
-      local xv = image.vflip(xi)
-      local yv = image.vflip(yi)
-      local rv = image.vflip(ri)
-      table.insert(xs, xi)
-      table.insert(ys, yi)
-      table.insert(lowreses, ri)
-      table.insert(xs, xv)
-      table.insert(ys, yv)
-      table.insert(lowreses, rv)
-      table.insert(xs, image.hflip(xi))
-      table.insert(ys, image.hflip(yi))
-      table.insert(lowreses, image.hflip(ri))
-      table.insert(xs, image.hflip(xv))
-      table.insert(ys, image.hflip(yv))
-      table.insert(lowreses, image.hflip(rv))
-   end
+   local xs, ys, ls, _ = pairwise_utils.flip_augmentation(x, y, lowres_y)
    for i = 1, n do
       local t = (i % #xs) + 1
-      local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], lowreses[t],
+      local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t],
 						    size,
 						    scale_inner,
 						    options.active_cropping_rate,

+ 59 - 0
lib/pairwise_transform_utils.lua

@@ -98,5 +98,64 @@ function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p
       return xc, yc
    end
 end
+function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
+   local xs = {}
+   local ns = {}
+   local ys = {}
+   local ls = {}
+
+   for j = 1, 2 do
+      -- TTA
+      local xi, yi, ri
+      if j == 1 then
+	 xi = x
+	 ni = x_noise
+	 yi = y
+	 ri = lowres_y
+      else
+	 xi = x:transpose(2, 3):contiguous()
+	 if x_noise then
+	    ni = x_noise:transpose(2, 3):contiguous()
+	 end
+	 yi = y:transpose(2, 3):contiguous()
+	 ri = lowres_y:transpose(2, 3):contiguous()
+      end
+      local xv = image.vflip(xi)
+      local nv
+      if x_noise then
+	 nv = image.vflip(ni)
+      end
+      local yv = image.vflip(yi)
+      local rv = image.vflip(ri)
+      table.insert(xs, xi)
+      if ni then
+	 table.insert(ns, ni)
+      end
+      table.insert(ys, yi)
+      table.insert(ls, ri)
+
+      table.insert(xs, xv)
+      if nv then
+	 table.insert(ns, nv)
+      end
+      table.insert(ys, yv)
+      table.insert(ls, rv)
+
+      table.insert(xs, image.hflip(xi))
+      if ni then
+	 table.insert(ns, image.hflip(ni))
+      end
+      table.insert(ys, image.hflip(yi))
+      table.insert(ls, image.hflip(ri))
+
+      table.insert(xs, image.hflip(xv))
+      if nv then
+	 table.insert(ns, image.hflip(nv))
+      end
+      table.insert(ys, image.hflip(yv))
+      table.insert(ls, image.hflip(rv))
+   end
+   return xs, ys, ls, ns
+end
 
 return pairwise_transform_utils