Bläddra i källkod

Add trade-off parameter for noise reduction

nagadomi 9 år sedan
förälder
incheckning
c773e18e59
3 ändrade filer med 31 tillägg och 46 borttagningar
  1. 29 45
      lib/pairwise_transform.lua
  2. 1 0
      lib/settings.lua
  3. 1 1
      train.lua

+ 29 - 45
lib/pairwise_transform.lua

@@ -158,68 +158,52 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
 	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
 	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
       end
-      table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+      if torch.uniform() < options.nr_rate then
+	 -- reductiong noise
+	 table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+      else
+	 -- ratain useful details
+	 table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
+      end
    end
    return batch
 end
 function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
    if style == "art" then
       if level == 1 then
-	 if torch.uniform() > 0.8 then
-	    return pairwise_transform.jpeg_(src, {},
-					    size, offset, n, options)
-	 else
-	    return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
-					    size, offset, n, options)
-	 end
+	 return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
+					 size, offset, n, options)
       elseif level == 2 then
 	 local r = torch.uniform()
-	 if torch.uniform() > 0.9 then
-	    return pairwise_transform.jpeg_(src, {},
+	 if r > 0.6 then
+	    return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
+					    size, offset, n, options)
+	 elseif r > 0.3 then
+	    local quality1 = torch.random(37, 70)
+	    local quality2 = quality1 - torch.random(5, 10)
+	    return pairwise_transform.jpeg_(src, {quality1, quality2},
 					    size, offset, n, options)
 	 else
-	    if r > 0.6 then
-	       return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
-					       size, offset, n, options)
-	    elseif r > 0.3 then
-	       local quality1 = torch.random(37, 70)
-	       local quality2 = quality1 - torch.random(5, 10)
-	       return pairwise_transform.jpeg_(src, {quality1, quality2},
-					       size, offset, n, options)
-	    else
-	       local quality1 = torch.random(52, 70)
-	       local quality2 = quality1 - torch.random(5, 15)
-	       local quality3 = quality1 - torch.random(15, 25)
-	       
-	       return pairwise_transform.jpeg_(src, 
-					       {quality1, quality2, quality3},
-					       size, offset, n, options)
-	    end
+	    local quality1 = torch.random(52, 70)
+	    local quality2 = quality1 - torch.random(5, 15)
+	    local quality3 = quality1 - torch.random(15, 25)
+	    
+	    return pairwise_transform.jpeg_(src, 
+					    {quality1, quality2, quality3},
+					    size, offset, n, options)
 	 end
       else
 	 error("unknown noise level: " .. level)
       end
    elseif style == "photo" then
       if level == 1 then
-	 if torch.uniform() > 0.7 then
-	    return pairwise_transform.jpeg_(src, {},
-					    size, offset, n,
-					    options)
-	 else
-	    return pairwise_transform.jpeg_(src, {torch.random(80, 95)},
-					    size, offset, n,
-					    options)
-	 end
+	 return pairwise_transform.jpeg_(src, {torch.random(80, 95)},
+					 size, offset, n,
+					 options)
       elseif level == 2 then
-	 if torch.uniform() > 0.7 then
-	    return pairwise_transform.jpeg_(src, {},
-					    size, offset, n,
-					    options)
-	 else
-	    return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
-					    size, offset, n,
-					    options)
-	 end
+	 return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
+					 size, offset, n,
+					 options)
       else
 	 error("unknown noise level: " .. level)
       end

+ 1 - 0
lib/settings.lua

@@ -41,6 +41,7 @@ cmd:option("-validation_rate", 0.05, 'validation-set rate of data')
 cmd:option("-validation_crops", 80, 'number of region per image in validation')
 cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
 cmd:option("-active_cropping_tries", 10, 'active cropping tries')
+cmd:option("-nr_rate", 0.7, 'trade-off between reducing noise and erasing details (0.0-1.0)')
 
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do

+ 1 - 1
train.lua

@@ -89,7 +89,6 @@ local function transformer(x, is_validation, n, offset)
    local overlay = nil
    local active_cropping_rate = nil
    local active_cropping_tries = nil
-   
    if is_validation then
       active_cropping_rate = 0
       active_cropping_tries = 0
@@ -128,6 +127,7 @@ local function transformer(x, is_validation, n, offset)
 				       jpeg_sampling_factors = settings.jpeg_sampling_factors,
 				       active_cropping_rate = active_cropping_rate,
 				       active_cropping_tries = active_cropping_tries,
+				       nr_rate = settings.nr_rate,
 				       rgb = (settings.color == "rgb")
 				     })
    end