
Use MSE instead of PSNR

PSNR depends on the minibatch size and on how samples are grouped into batches, so track the average MSE instead and convert it to PSNR once at the end.
nagadomi · 9 years ago · commit 68a6d4cef5

2 changed files with 6 additions and 8 deletions
  1. lib/minibatch_adam.lua (+1 -2)
  2. train.lua (+5 -6)
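
For context (not part of the commit): a minimal plain-Lua sketch of the motivation. Averaging per-batch PSNR depends on how samples happen to be grouped into batches, whereas averaging MSE and converting once does not. The per-batch MSE values below are made up for illustration, and pixel values are assumed to be normalized to [0, 1].

    local function psnr(mse) return 10 * math.log10(1 / mse) end

    -- hypothetical per-batch MSE values
    local batch_mse = {0.0001, 0.01}

    -- mean of per-batch PSNR: changes if the same samples are regrouped
    local sum_psnr = 0
    for _, mse in ipairs(batch_mse) do sum_psnr = sum_psnr + psnr(mse) end
    print(("mean of per-batch PSNR: %.2f dB"):format(sum_psnr / #batch_mse))      -- 30.00 dB

    -- PSNR of the mean MSE: what the training loop reports after this commit
    local sum_mse = 0
    for _, mse in ipairs(batch_mse) do sum_mse = sum_mse + mse end
    print(("PSNR of mean MSE:       %.2f dB"):format(psnr(sum_mse / #batch_mse))) -- ~22.97 dB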

+ 1 - 2
lib/minibatch_adam.lua

@@ -52,8 +52,7 @@ local function minibatch_adam(model, criterion, eval_metric,
       end
    end
    xlua.progress(train_x:size(1), train_x:size(1))
-   
-   return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss}
+   return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}
 end
 
 return minibatch_adam

+ 5 - 6
train.lua

@@ -198,7 +198,7 @@ local function train()
       return transformer(x, is_validation, n, offset)
    end
    local criterion = create_criterion(model)
-   local eval_metric = w2nn.PSNRCriterion():cuda()
+   local eval_metric = nn.MSECriterion():cuda()
    local x = torch.load(settings.images)
    local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
    local adam_config = {
@@ -212,7 +212,7 @@ local function train()
    elseif settings.color == "rgb" then
       ch = 3
    end
-   local best_score = 0.0
+   local best_score = 1000.0
    print("# make validation-set")
    local valid_xy = make_validation_set(valid_x, pairwise_func,
 					settings.validation_crops,
@@ -227,7 +227,6 @@ local function train()
 			  ch, settings.crop_size, settings.crop_size)
    local y = torch.Tensor(settings.patches * #train_x,
 			  ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
-
    for epoch = 1, settings.epoch do
       model:training()
       print("# " .. epoch)
@@ -238,12 +237,12 @@ local function train()
 	 model:evaluate()
 	 print("# validation")
 	 local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
-	 table.insert(hist_train, train_score.PSNR)
+	 table.insert(hist_train, train_score.MSE)
 	 table.insert(hist_valid, score)
 	 if settings.plot then
 	    plot(hist_train, hist_valid)
 	 end
-	 if score > best_score then
+	 if score < best_score then
 	    local test_image = image_loader.load_float(settings.test) -- reload
 	    lrd_count = 0
 	    best_score = score
@@ -281,7 +280,7 @@ local function train()
 	       lrd_count = 0
 	    end
 	 end
-	 print("current: " .. score .. ", best: " .. best_score)
+	 print("PSNR: " .. 10 * math.log10(1 / score) .. ", MSE: " .. score .. ", Best MSE: " .. best_score)
 	 collectgarbage()
       end
    end