Parcourir la source

Add support for grayscale data

nagadomi il y a 8 ans
Parent
commit
f0fc2c89d1

+ 7 - 0
convert_data.lua

@@ -123,6 +123,10 @@ local function load_images(list)
 		  end
 		  end
 		  xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
 		  xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
 		  xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
 		  xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
+		  if settings.grayscale then
+		     xx = iproc.rgb2y(xx)
+		     yy = iproc.rgb2y(yy)
+		  end
 		  table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
 		  table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
 				  {data = {filters = filters, has_x = true}}})
 				  {data = {filters = filters, has_x = true}}})
 	       else
 	       else
@@ -137,6 +141,9 @@ local function load_images(list)
 		  scale = 2.0
 		  scale = 2.0
 	       end
 	       end
 	       if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
 	       if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
+		  if settings.grayscale then
+		     im = iproc.rgb2y(im)
+		  end
 		  table.insert(x, {compression.compress(im), {data = {filters = filters}}})
 		  table.insert(x, {compression.compress(im), {data = {filters = filters}}})
 	       else
 	       else
 		  io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
 		  io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))

+ 1 - 0
lib/iproc.lua

@@ -229,6 +229,7 @@ function iproc.rgb2y(src)
    src, conversion = iproc.byte2float(src)
    src, conversion = iproc.byte2float(src)
    local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
    local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
    dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
    dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
+   dest:clamp(0, 1)
    if conversion then
    if conversion then
       dest = iproc.float2byte(dest)
       dest = iproc.float2byte(dest)
    end
    end

+ 4 - 2
lib/pairwise_transform_jpeg.lua

@@ -43,8 +43,10 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
       yc = iproc.byte2float(yc)
       yc = iproc.byte2float(yc)
       if options.rgb then
       if options.rgb then
       else
       else
-	 yc = iproc.rgb2y(yc)
-	 xc = iproc.rgb2y(xc)
+	 if xc:size(1) > 1 then
+	    yc = iproc.rgb2y(yc)
+	    xc = iproc.rgb2y(xc)
+	 end
       end
       end
       if torch.uniform() < options.nr_rate then
       if torch.uniform() < options.nr_rate then
 	 -- reducing noise
 	 -- reducing noise

+ 4 - 2
lib/pairwise_transform_scale.lua

@@ -51,8 +51,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       yc = iproc.byte2float(yc)
       yc = iproc.byte2float(yc)
       if options.rgb then
       if options.rgb then
       else
       else
-	 yc = iproc.rgb2y(yc)
-	 xc = iproc.rgb2y(xc)
+	 if xc:size(1) > 1 then
+	    yc = iproc.rgb2y(yc)
+	    xc = iproc.rgb2y(xc)
+	 end
       end
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end
    end

+ 4 - 2
lib/pairwise_transform_user.lua

@@ -38,8 +38,10 @@ function pairwise_transform.user(x, y, size, offset, n, options)
       yc = iproc.byte2float(yc)
       yc = iproc.byte2float(yc)
       if options.rgb then
       if options.rgb then
       else
       else
-	 yc = iproc.rgb2y(yc)
-	 xc = iproc.rgb2y(xc)
+	 if xc:size(1) > 1 then
+	    yc = iproc.rgb2y(yc)
+	    xc = iproc.rgb2y(xc)
+	 end
       end
       end
       if options.gcn then
       if options.gcn then
 	 local mean = xc:mean()
 	 local mean = xc:mean()

+ 11 - 4
lib/pairwise_transform_utils.lua

@@ -279,10 +279,17 @@ function pairwise_transform_utils.low_resolution(src)
 	    toTensor("byte", "RGB", "DHW")
 	    toTensor("byte", "RGB", "DHW")
    end
    end
 --]]
 --]]
-   return gm.Image(src, "RGB", "DHW"):
-      size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
-      size(src:size(3), src:size(2), "Box"):
-      toTensor("byte", "RGB", "DHW")
+   if src:size(1) == 1 then
+      return gm.Image(src, "I", "DHW"):
+	 size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+	 size(src:size(3), src:size(2), "Box"):
+	 toTensor("byte", "I", "DHW")
+   else
+      return gm.Image(src, "RGB", "DHW"):
+	 size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+	 size(src:size(3), src:size(2), "Box"):
+	 toTensor("byte", "RGB", "DHW")
+   end
 end
 end
 
 
 return pairwise_transform_utils
 return pairwise_transform_utils

+ 2 - 0
lib/settings.lua

@@ -79,6 +79,7 @@ cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
 cmd:option("-update_criterion", "mse", 'mse|loss')
 cmd:option("-update_criterion", "mse", 'mse|loss')
 cmd:option("-padding", 0, 'replication padding size')
 cmd:option("-padding", 0, 'replication padding size')
 cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
 cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
+cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
 
 
 local function to_bool(settings, name)
 local function to_bool(settings, name)
    if settings[name] == 1 then
    if settings[name] == 1 then
@@ -98,6 +99,7 @@ to_bool(settings, "use_transparent_png")
 to_bool(settings, "pairwise_y_binary")
 to_bool(settings, "pairwise_y_binary")
 to_bool(settings, "pairwise_flip")
 to_bool(settings, "pairwise_flip")
 to_bool(settings, "padding_y_zero")
 to_bool(settings, "padding_y_zero")
+to_bool(settings, "grayscale")
 
 
 if settings.plot then
 if settings.plot then
    require 'gnuplot'
    require 'gnuplot'