
Use roundf-like clip for 8-bit-depth images

This commit may improve PSNR by about +0.03.
nagadomi committed 9 years ago (commit 797b45ae23)
4 changed files with 33 additions and 10 deletions
  1. lib/ClippedWeightedHuberCriterion.lua (+7 -4)
  2. lib/image_loader.lua (+24 -4)
  3. lib/w2nn.lua (+1 -1)
  4. train.lua (+1 -1)
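
As context for the diffs below, here is a minimal sketch (plain Lua, not part of the repository) of the "roundf-like clip" named in the commit message: if the float-to-8-bit depth conversion truncates, adding just under half a quantization step (the clip_eta8 constant introduced in lib/image_loader.lua) beforehand makes the conversion behave like roundf(). The to_byte_truncate helper is a hypothetical stand-in for that conversion, not a function from the codebase.

-- clip_eta8 as defined in this commit: half an 8-bit step, minus a tiny epsilon
local clip_eta8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)

-- hypothetical stand-in for a truncating float [0,1] -> 8-bit conversion
local function to_byte_truncate(x)
   if x < 0.0 then x = 0.0 end
   if x > 1.0 then x = 1.0 end
   return math.floor(x * 255.0)
end

local v = 200.7 / 255.0
print(to_byte_truncate(v))              -- 200 (truncated)
print(to_byte_truncate(v + clip_eta8))  -- 201 (rounded, as roundf(200.7) would give)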

lib/WeightedHuberCriterion.lua → lib/ClippedWeightedHuberCriterion.lua (+7 -4)

@@ -1,8 +1,9 @@
 -- ref: https://en.wikipedia.org/wiki/Huber_loss
-local WeightedHuberCriterion, parent = torch.class('w2nn.WeightedHuberCriterion','nn.Criterion')
+local ClippedWeightedHuberCriterion, parent = torch.class('w2nn.ClippedWeightedHuberCriterion','nn.Criterion')
 
-function WeightedHuberCriterion:__init(w, gamma)
+function ClippedWeightedHuberCriterion:__init(w, gamma, clip)
    parent.__init(self)
+   self.clip = clip
    self.gamma = gamma or 1.0
    self.weight = w:clone()
    self.diff = torch.Tensor()
@@ -11,8 +12,10 @@ function WeightedHuberCriterion:__init(w, gamma)
    self.square_loss_buff = torch.Tensor()
    self.linear_loss_buff = torch.Tensor()
 end
-function WeightedHuberCriterion:updateOutput(input, target)
+function ClippedWeightedHuberCriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
+   self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1]
+   self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2]
    for i = 1, input:size(1) do
       self.diff[i]:add(-1, target[i]):cmul(self.weight)
    end
@@ -27,7 +30,7 @@ function WeightedHuberCriterion:updateOutput(input, target)
    self.output = (square_loss + linear_loss) / input:nElement()
    return self.output
 end
-function WeightedHuberCriterion:updateGradInput(input, target)
+function ClippedWeightedHuberCriterion:updateGradInput(input, target)
    local norm = 1.0 / input:nElement()
    self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
    local outlier = torch.ge(self.diff_abs, self.gamma)
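
The two added lines that clamp self.diff use a standard Torch7 masking idiom. A minimal standalone sketch (assuming only the torch package is installed; the tensor values are made up):

require 'torch'

local clip = {0.0, 1.0}                   -- the range train.lua passes to the criterion
local t = torch.Tensor({-0.2, 0.4, 1.3})  -- made-up values standing in for network output
t[torch.lt(t, clip[1])] = clip[1]         -- elements below the lower bound become clip[1]
t[torch.gt(t, clip[2])] = clip[2]         -- elements above the upper bound become clip[2]
print(t)                                  -- -0.2 and 1.3 are clamped to 0.0 and 1.0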

lib/image_loader.lua (+24 -4)

@@ -4,6 +4,9 @@ require 'pl'
 
 local image_loader = {}
 
+local clip_eta8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
+local clip_eta16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
+
 function image_loader.decode_float(blob)
    local im, alpha = image_loader.decode_byte(blob)
    if im then
@@ -25,13 +28,30 @@ function image_loader.encode_png(rgb, alpha, depth)
       rgba[2]:copy(rgb[2])
       rgba[3]:copy(rgb[3])
       rgba[4]:copy(alpha)
+      
+      if depth < 16 then
+	 rgba:add(clip_eta8)
+	 rgba[torch.lt(rgba, 0.0)] = 0.0
+	 rgba[torch.gt(rgba, 1.0)] = 1.0
+      else
+	 rgba:add(clip_eta16)
+	 rgba[torch.lt(rgba, 0.0)] = 0.0
+	 rgba[torch.gt(rgba, 1.0)] = 1.0
+      end
       local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
-      im:format("png")
-      return im:depth(depth):toBlob(9)
+      return im:depth(depth):format("PNG"):toBlob(9)
    else
+      if depth < 16 then
+	 rgb = rgb:clone():add(clip_eta8)
+	 rgb[torch.lt(rgb, 0.0)] = 0.0
+	 rgb[torch.gt(rgb, 1.0)] = 1.0
+      else
+	 rgb = rgb:clone():add(clip_eta16)
+	 rgb[torch.lt(rgb, 0.0)] = 0.0
+	 rgb[torch.gt(rgb, 1.0)] = 1.0
+      end
       local im = gm.Image(rgb, "RGB", "DHW")
-      im:format("png")
-      return im:depth(depth):toBlob(9)
+      return im:depth(depth):format("PNG"):toBlob(9)
    end
 end
 function image_loader.save_png(filename, rgb, alpha, depth)
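
A hypothetical usage sketch of the patched encoder (assumptions: torch and graphicsmagick are installed, lib/ is on package.path, image_loader.lua returns its table at the end of the file, and passing nil for alpha takes the RGB-only branch shown above):

require 'torch'
local image_loader = require 'image_loader'  -- assumption: lib/ is in package.path

local rgb = torch.rand(3, 64, 64)            -- float RGB image in [0, 1], DHW layout
-- depth 8 takes the "< 16" branch above: add clip_eta8, clamp to [0, 1], encode as PNG
local blob = image_loader.encode_png(rgb, nil, 8)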

lib/w2nn.lua (+1 -1)

@@ -20,7 +20,7 @@ else
    require 'LeakyReLU_deprecated'
    require 'DepthExpand2x'
    require 'WeightedMSECriterion'
-   require 'WeightedHuberCriterion'
+   require 'ClippedWeightedHuberCriterion'
    require 'cleanup_model'
    return w2nn
 end

train.lua (+1 -1)

@@ -76,7 +76,7 @@ local function create_criterion(model)
       weight[1]:fill(0.29891 * 3) -- R
       weight[2]:fill(0.58661 * 3) -- G
       weight[3]:fill(0.11448 * 3) -- B
-      return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
+      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
    else
       return nn.MSECriterion():cuda()
    end