Bladeren bron

Merge branch 'dev'

nagadomi 8 jaren geleden
bovenliggende
commit
e371955cc7

+ 27 - 1
convert_data.lua

@@ -63,7 +63,24 @@ local function crop_if_large_pair(x, y, max_size)
       return x, y
    end
 end
-
+local function padding_x(x, pad) -- adds a border of `pad` to all four sides of x via iproc.padding
+   if not (pad > 0) then
+      return x -- no positive padding requested: hand the image back untouched
+   end
+   return (iproc.padding(x, pad, pad, pad, pad)) -- parentheses keep this to exactly one return value
+end
+local function padding_xy(x, y, pad, y_zero) -- pads an (x, y) pair; y's border is `pad` scaled by the y/x size ratio
+   local ratio = y:size(2) / x:size(2) -- measured on dimension 2 before either image is padded
+   if not (pad > 0) then
+      return x, y -- no positive padding requested: both images are returned unchanged
+   end
+   local border = pad * ratio
+   x = iproc.padding(x, pad, pad, pad, pad)
+   -- y_zero switches y to iproc.zero_padding; otherwise y goes through the same iproc.padding as x
+   local pad_y = y_zero and iproc.zero_padding or iproc.padding
+   y = pad_y(y, border, border, border, border)
+   return x, y
+end
 local function load_images(list)
    local MARGIN = 32
    local csv = csvigo.load({path = list, verbose = false, mode = "raw"})
@@ -105,6 +122,11 @@ local function load_images(list)
 		     xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
 		  end
 		  xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
+		  xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
+		  if settings.grayscale then
+		     xx = iproc.rgb2y(xx)
+		     yy = iproc.rgb2y(yy)
+		  end
 		  table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
 				  {data = {filters = filters, has_x = true}}})
 	       else
@@ -113,11 +135,15 @@ local function load_images(list)
 	    else
 	       im = crop_if_large(im, settings.max_training_image_size)
 	       im = iproc.crop_mod4(im)
+	       im = padding_x(im, settings.padding)
 	       local scale = 1.0
 	       if settings.random_half_rate > 0.0 then
 		  scale = 2.0
 	       end
 	       if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
+		  if settings.grayscale then
+		     im = iproc.rgb2y(im)
+		  end
 		  table.insert(x, {compression.compress(im), {data = {filters = filters}}})
 	       else
 		  io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))

+ 42 - 0
lib/ShakeShakeTable.lua

@@ -0,0 +1,42 @@
+local ShakeShakeTable, parent = torch.class('w2nn.ShakeShakeTable','nn.Module')
+-- Merges a table of two same-sized batched tensors, using random per-sample weights while self.train is set.
+function ShakeShakeTable:__init()
+   parent.__init(self)
+   self.alpha = torch.Tensor() -- per-sample forward weights, redrawn on every training forward pass
+   self.beta = torch.Tensor() -- per-sample backward weights, redrawn together with alpha
+   self.first = torch.Tensor() -- buffer: gradient for input[1]
+   self.second = torch.Tensor() -- buffer: weighted input[2] (forward) or gradient for input[2] (backward)
+   self.train = true
+end
+function ShakeShakeTable:updateOutput(input) -- train: 2 * ((1 - alpha) * input[1] + alpha * input[2]); else input[1] + input[2]
+   local batch_size = input[1]:size(1)
+   if self.train then
+      self.alpha:resize(batch_size):uniform()
+      self.beta:resize(batch_size):uniform() -- drawn here, consumed later by updateGradInput
+      self.second:resizeAs(input[1]):copy(input[2])
+      for i = 1, batch_size do
+         self.second[i]:mul(self.alpha[i])
+      end
+      self.output:resizeAs(input[1]):copy(input[1])
+      for i = 1, batch_size do
+         self.output[i]:mul(1.0 - self.alpha[i])
+      end
+      self.output:add(self.second):mul(2) -- the two weights sum to 1, so doubling matches the plain sum used below
+   else
+      self.output:resizeAs(input[1]):copy(input[1]):add(input[2])
+   end
+   return self.output
+end
+function ShakeShakeTable:updateGradInput(input, gradOutput) -- returns {gradOutput * beta, gradOutput * (1 - beta)} per sample
+   local batch_size = input[1]:size(1)
+   self.first:resizeAs(gradOutput):copy(gradOutput)
+   for i = 1, batch_size do
+      self.first[i]:mul(self.beta[i]) -- beta comes from the last training-mode updateOutput call
+   end
+   self.second:resizeAs(gradOutput):copy(gradOutput)
+   for i = 1, batch_size do
+      self.second[i]:mul(1.0 - self.beta[i])
+   end
+   self.gradInput = {self.first, self.second} -- fix: was self.gradOutput; nn.Module:backward returns self.gradInput
+   return self.gradInput
+end

+ 15 - 2
lib/iproc.lua

@@ -80,6 +80,8 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
    return dest
 end
 function iproc.padding(img, w1, w2, h1, h2)
+   local conversion
+   img, conversion = iproc.byte2float(img)
    image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
@@ -88,9 +90,15 @@ function iproc.padding(img, w1, w2, h1, h2)
    flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
    flow[1]:add(-h1)
    flow[2]:add(-w1)
-   return image.warp(img, flow, "simple", false, "clamp")
+   local dest = image.warp(img, flow, "simple", false, "clamp")
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
 end
 function iproc.zero_padding(img, w1, w2, h1, h2)
+   local conversion
+   img, conversion = iproc.byte2float(img)
    image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
@@ -99,7 +107,11 @@ function iproc.zero_padding(img, w1, w2, h1, h2)
    flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
    flow[1]:add(-h1)
    flow[2]:add(-w1)
-   return image.warp(img, flow, "simple", false, "pad", 0)
+   local dest = image.warp(img, flow, "simple", false, "pad", 0)
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
 end
 function iproc.white_noise(src, std, rgb_weights, gamma)
    gamma = gamma or 0.454545
@@ -217,6 +229,7 @@ function iproc.rgb2y(src)
    src, conversion = iproc.byte2float(src)
    local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
    dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
+   dest:clamp(0, 1)
    if conversion then
       dest = iproc.float2byte(dest)
    end

+ 4 - 2
lib/pairwise_transform_jpeg.lua

@@ -43,8 +43,10 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = iproc.rgb2y(yc)
-	 xc = iproc.rgb2y(xc)
+	 if xc:size(1) > 1 then
+	    yc = iproc.rgb2y(yc)
+	    xc = iproc.rgb2y(xc)
+	 end
       end
       if torch.uniform() < options.nr_rate then
 	 -- reducing noise

+ 4 - 2
lib/pairwise_transform_scale.lua

@@ -51,8 +51,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = iproc.rgb2y(yc)
-	 xc = iproc.rgb2y(xc)
+	 if xc:size(1) > 1 then
+	    yc = iproc.rgb2y(yc)
+	    xc = iproc.rgb2y(xc)
+	 end
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end

+ 4 - 2
lib/pairwise_transform_user.lua

@@ -38,8 +38,10 @@ function pairwise_transform.user(x, y, size, offset, n, options)
       yc = iproc.byte2float(yc)
       if options.rgb then
       else
-	 yc = iproc.rgb2y(yc)
-	 xc = iproc.rgb2y(xc)
+	 if xc:size(1) > 1 then
+	    yc = iproc.rgb2y(yc)
+	    xc = iproc.rgb2y(xc)
+	 end
       end
       if options.gcn then
 	 local mean = xc:mean()

+ 11 - 4
lib/pairwise_transform_utils.lua

@@ -279,10 +279,17 @@ function pairwise_transform_utils.low_resolution(src)
 	    toTensor("byte", "RGB", "DHW")
    end
 --]]
-   return gm.Image(src, "RGB", "DHW"):
-      size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
-      size(src:size(3), src:size(2), "Box"):
-      toTensor("byte", "RGB", "DHW")
+   if src:size(1) == 1 then
+      return gm.Image(src, "I", "DHW"):
+	 size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+	 size(src:size(3), src:size(2), "Box"):
+	 toTensor("byte", "I", "DHW")
+   else
+      return gm.Image(src, "RGB", "DHW"):
+	 size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+	 size(src:size(3), src:size(2), "Box"):
+	 toTensor("byte", "RGB", "DHW")
+   end
 end
 
 return pairwise_transform_utils

+ 5 - 0
lib/settings.lua

@@ -76,6 +76,9 @@ cmd:option("-name", "user", 'model name for user method')
 cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
 cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
 cmd:option("-update_criterion", "mse", 'mse|loss')
+cmd:option("-padding", 0, 'replication padding size')
+cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
+cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -94,6 +97,8 @@ to_bool(settings, "save_history")
 to_bool(settings, "use_transparent_png")
 to_bool(settings, "pairwise_y_binary")
 to_bool(settings, "pairwise_flip")
+to_bool(settings, "padding_y_zero")
+to_bool(settings, "grayscale")
 
 if settings.plot then
    require 'gnuplot'

+ 1 - 0
lib/w2nn.lua

@@ -74,5 +74,6 @@ else
    require 'SSIMCriterion'
    require 'InplaceClip01'
    require 'L1Criterion'
+   require 'ShakeShakeTable'
    return w2nn
 end