Kaynağa Gözat

Fix a bug that -nr_rate is not used

nagadomi 9 yıl önce
ebeveyn
işleme
9514027f65
1 değiştirilmiş dosya ile 21 ekleme ve 7 silme
  1. 21 7
      lib/pairwise_transform_jpeg_scale.lua

+ 21 - 7
lib/pairwise_transform_jpeg_scale.lua

@@ -38,6 +38,7 @@ local function add_jpeg_noise(src, style, level, options)
 	    local quality1 = torch.random(52, 70)
 	    local quality2 = quality1 - torch.random(5, 15)
 	    local quality3 = quality1 - torch.random(15, 25)
+
 	    return add_jpeg_noise_(src, {quality1, quality2, quality3}, options)
 	 end
       else
@@ -80,7 +81,6 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off
 	 x = small
       end
    end
-   x = add_jpeg_noise(x, style, noise_level, options)
    local scale_inner = scale
    if options.x_upsampling then
       scale_inner = 1
@@ -101,9 +101,9 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off
       size(y:size(3), y:size(2), "Box"):
       toTensor(t, "RGB", "DHW")
    local xs = {}
+   local ns = {}
    local ys = {}
    local lowreses = {}
-
    for j = 1, 2 do
       -- TTA
       local xi, yi, ri
@@ -134,11 +134,25 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off
    end
    for i = 1, n do
       local t = (i % #xs) + 1
-      local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], lowreses[t],
-						    size,
-						    scale_inner,
-						    options.active_cropping_rate,
-						    options.active_cropping_tries)
+      local xc, yc
+      if torch.uniform() < options.nr_rate then
+	 -- scale + noise reduction
+	 if not ns[t] then
+	    ns[t] = add_jpeg_noise(xs[t], style, noise_level, options)
+	 end
+	 xc, yc = pairwise_utils.active_cropping(ns[t], ys[t], lowreses[t],
+						 size,
+						 scale_inner,
+						 options.active_cropping_rate,
+						 options.active_cropping_tries)
+      else
+	 -- scale
+	 xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], lowreses[t],
+						 size,
+						 scale_inner,
+						 options.active_cropping_rate,
+						 options.active_cropping_tries)
+      end
       xc = iproc.byte2float(xc)
       yc = iproc.byte2float(yc)
       if options.rgb then