Browse Source

Add L1 criterion. Change the criterion of updating model

nagadomi 8 years ago
parent
commit
d2cfb8f104
4 changed files with 51 additions and 16 deletions
  1. 27 0
      lib/L1Criterion.lua
  2. 1 0
      lib/settings.lua
  3. 1 0
      lib/w2nn.lua
  4. 22 16
      train.lua

+ 27 - 0
lib/L1Criterion.lua

@@ -0,0 +1,27 @@
+-- ref: https://en.wikipedia.org/wiki/L1_loss
+local L1Criterion, parent = torch.class('w2nn.L1Criterion','nn.Criterion')
+
+function L1Criterion:__init()
+   parent.__init(self)
+   self.diff = torch.Tensor()
+   self.linear_loss_buff = torch.Tensor()
+end
+function L1Criterion:updateOutput(input, target)
+   self.diff:resizeAs(input):copy(input)
+   if input:dim() == 1 then
+      self.diff[1] = input[1] - target
+   else
+      for i = 1, input:size(1) do
+	 self.diff[i]:add(-1, target[i])
+      end
+   end
+   local linear_targets = self.diff
+   local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():sum()
+   self.output = (linear_loss) / input:nElement()
+   return self.output
+end
+function L1Criterion:updateGradInput(input, target)
+   local norm = 1.0 / input:nElement()
+   self.gradInput:resizeAs(self.diff):copy(self.diff):sign():mul(norm)
+   return self.gradInput
+end

+ 1 - 0
lib/settings.lua

@@ -75,6 +75,7 @@ cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate *
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
 cmd:option("-gpu", 1, 'Device ID')
+cmd:option("-loss", "huber", 'loss function (huber|l1)')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then

+ 1 - 0
lib/w2nn.lua

@@ -32,5 +32,6 @@ else
    require 'ClippedMSECriterion'
    require 'SSIMCriterion'
    require 'InplaceClip01'
+   require 'L1Criterion'
    return w2nn
 end

+ 22 - 16
train.lua

@@ -298,20 +298,26 @@ local function validate(model, criterion, eval_metric, data, batch_size)
 end
 
 local function create_criterion(model)
-   if reconstruct.is_rgb(model) then
-      local offset = reconstruct.offset_size(model)
-      local output_w = settings.crop_size - offset * 2
-      local weight = torch.Tensor(3, output_w * output_w)
-      weight[1]:fill(0.29891 * 3) -- R
-      weight[2]:fill(0.58661 * 3) -- G
-      weight[3]:fill(0.11448 * 3) -- B
-      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
+   if settings.loss == "huber" then
+      if reconstruct.is_rgb(model) then
+	 local offset = reconstruct.offset_size(model)
+	 local output_w = settings.crop_size - offset * 2
+	 local weight = torch.Tensor(3, output_w * output_w)
+	 weight[1]:fill(0.29891 * 3) -- R
+	 weight[2]:fill(0.58661 * 3) -- G
+	 weight[3]:fill(0.11448 * 3) -- B
+	 return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
+      else
+	 local offset = reconstruct.offset_size(model)
+	 local output_w = settings.crop_size - offset * 2
+	 local weight = torch.Tensor(1, output_w * output_w)
+	 weight[1]:fill(1.0)
+	 return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
+      end
+   elseif settings.loss == "l1" then
+      return w2nn.L1Criterion():cuda()
    else
-      local offset = reconstruct.offset_size(model)
-      local output_w = settings.crop_size - offset * 2
-      local weight = torch.Tensor(1, output_w * output_w)
-      weight[1]:fill(1.0)
-      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
+      error("unsupported loss .." .. settings.loss)
    end
 end
 
@@ -518,9 +524,9 @@ local function train()
 	 if settings.plot then
 	    plot(hist_train, hist_valid)
 	 end
-	 if score.MSE < best_score then
+	 if score.loss < best_score then
 	    local test_image = image_loader.load_float(settings.test) -- reload
-	    best_score = score.MSE
+	    best_score = score.loss
 	    print("* model has updated")
 	    if settings.save_history then
 	       torch.save(settings.model_file_best, model:clearState(), "ascii")
@@ -569,7 +575,7 @@ local function train()
 	       end
 	    end
 	 end
-	 print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", Minimum MSE: " .. best_score)
+	 print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
 	 collectgarbage()
       end
    end