Browse Source

Use PSNR for evaluation

nagadomi 9 years ago
parent
commit
1900ac7500
4 changed files with 29 additions and 7 deletions
  1. 19 0
      lib/PSNRCriterion.lua
  2. 4 2
      lib/minibatch_adam.lua
  3. 1 1
      lib/w2nn.lua
  4. 5 4
      train.lua

+ 19 - 0
lib/PSNRCriterion.lua

@@ -0,0 +1,19 @@
+local PSNRCriterion, parent = torch.class('w2nn.PSNRCriterion','nn.Criterion')
+
+function PSNRCriterion:__init()
+   parent.__init(self)
+   self.image = torch.Tensor()
+   self.diff = torch.Tensor()
+end
+function PSNRCriterion:updateOutput(input, target)
+   self.image:resizeAs(input):copy(input)
+   self.image:clamp(0.0, 1.0)
+   self.diff:resizeAs(self.image):copy(self.image)
+   
+   local mse = self.diff:add(-1, target):pow(2):mean()
+   self.output = 10 * math.log10(1.0 / mse)
+   return self.output
+end
+function PSNRCriterion:updateGradInput(input, target)
+   error("PSNRCriterion does not support backward")
+end

+ 4 - 2
lib/minibatch_adam.lua

@@ -2,12 +2,13 @@ require 'optim'
 require 'cutorch'
 require 'xlua'
 
-local function minibatch_adam(model, criterion,
+local function minibatch_adam(model, criterion, eval_metric,
 			      train_x, train_y,
 			      config)
    local parameters, gradParameters = model:getParameters()
    config = config or {}
    local sum_loss = 0
+   local sum_eval = 0
    local count_loss = 0
    local batch_size = config.xBatchSize or 32
    local shuffle = torch.randperm(train_x:size(1))
@@ -39,6 +40,7 @@ local function minibatch_adam(model, criterion,
 	 gradParameters:zero()
 	 local output = model:forward(inputs)
 	 local f = criterion:forward(output, targets)
+	 sum_eval = sum_eval + eval_metric:forward(output, targets)
 	 sum_loss = sum_loss + f
 	 count_loss = count_loss + 1
 	 model:backward(inputs, criterion:backward(output, targets))
@@ -52,7 +54,7 @@ local function minibatch_adam(model, criterion,
    end
    xlua.progress(train_x:size(1), train_x:size(1))
    
-   return { loss = sum_loss / count_loss}
+   return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss}
 end
 
 return minibatch_adam

+ 1 - 1
lib/w2nn.lua

@@ -19,7 +19,7 @@ else
    require 'LeakyReLU'
    require 'LeakyReLU_deprecated'
    require 'DepthExpand2x'
-   require 'WeightedMSECriterion'
+   require 'PSNRCriterion'
    require 'ClippedWeightedHuberCriterion'
    return w2nn
 end

+ 5 - 4
train.lua

@@ -166,6 +166,7 @@ local function train()
       return transformer(x, is_validation, n, offset)
    end
    local criterion = create_criterion(model)
+   local eval_metric = w2nn.PSNRCriterion():cuda()
    local x = torch.load(settings.images)
    local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
    local adam_config = {
@@ -179,7 +180,7 @@ local function train()
    elseif settings.color == "rgb" then
       ch = 3
    end
-   local best_score = 100000.0
+   local best_score = 0.0
    print("# make validation-set")
    local valid_xy = make_validation_set(valid_x, pairwise_func,
 					settings.validation_crops,
@@ -200,11 +201,11 @@ local function train()
       print("# " .. epoch)
       resampling(x, y, train_x, pairwise_func)
       for i = 1, settings.inner_epoch do
-	 print(minibatch_adam(model, criterion, x, y, adam_config))
+	 print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config))
 	 model:evaluate()
 	 print("# validation")
-	 local score = validate(model, criterion, valid_xy)
-	 if score < best_score then
+	 local score = validate(model, eval_metric, valid_xy)
+	 if score > best_score then
 	    local test_image = image_loader.load_float(settings.test) -- reload
 	    lrd_count = 0
 	    best_score = score