فهرست منبع

Change the sampling method

nagadomi 9 سال پیش
والد
کامیت
86feb1d4c9
1فایلهای تغییر یافته به همراه16 افزوده شده و 15 حذف شده
  1. 16 15
      lib/pairwise_transform.lua

+ 16 - 15
lib/pairwise_transform.lua

@@ -46,6 +46,10 @@ end
 local function active_cropping(x, y, size, p, tries)
    assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
    local r = torch.uniform()
+   local t = "float"
+   if x:type() == "torch.ByteTensor" then
+      t = "byte"
+   end
    if p < r then
       local xi = torch.random(0, y:size(3) - (size + 1))
       local yi = torch.random(0, y:size(2) - (size + 1))
@@ -53,6 +57,10 @@ local function active_cropping(x, y, size, p, tries)
       local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
       return xc, yc
    else
+      local lowres = gm.Image(x, "RGB", "DHW"):
+	 size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"):
+	 size(x:size(3), x:size(2), "Box"):
+	 toTensor(t, "RGB", "DHW")
       local best_se = 0.0
       local best_xc, best_yc
       local m = torch.FloatTensor(x:size(1), size, size)
@@ -60,13 +68,13 @@ local function active_cropping(x, y, size, p, tries)
 	 local xi = torch.random(0, y:size(3) - (size + 1))
 	 local yi = torch.random(0, y:size(2) - (size + 1))
 	 local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
-	 local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
+	 local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
 	 local xcf = iproc.byte2float(xc)
-	 local ycf = iproc.byte2float(yc)
-	 local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
+	 local lcf = iproc.byte2float(lc)
+	 local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
 	 if se >= best_se then
 	    best_xc = xcf
-	    best_yc = ycf
+	    best_yc = iproc.byte2float(iproc.crop(y, xi, yi, xi + size, yi + size))
 	    best_se = se
 	 end
       end
@@ -199,17 +207,10 @@ function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
 	 error("unknown noise level: " .. level)
       end
    elseif style == "photo" then
-      if level == 1 then
-	 return pairwise_transform.jpeg_(src, {torch.random(70, 90)},
-					 size, offset, n,
-					 options)
-      elseif level == 2 then
-	 return pairwise_transform.jpeg_(src, {torch.random(50, 70)},
-					 size, offset, n,
-					 options)
-      else
-	 error("unknown noise level: " .. level)
-      end
+      -- level adjusting by -nr_rate
+      return pairwise_transform.jpeg_(src, {torch.random(50, 75)},
+				      size, offset, n,
+				      options)
    else
       error("unknown style: " .. style)
    end