Bladeren bron

Merge branch 'dev' of into dev

nagadomi 8 jaren geleden
bovenliggende
commit
8b5ccbed08
3 gewijzigde bestanden met toevoegingen van 18 en 8 verwijderingen
  1. 3 3
      lib/minibatch_adam.lua
  2. 1 0
      lib/settings.lua
  3. 14 5
      train.lua

+ 3 - 3
lib/minibatch_adam.lua

@@ -11,7 +11,7 @@ local function minibatch_adam(model, criterion, eval_metric,
       config.xEvalCount = 0
       config.learningRate = config.xLearningRate
    end
-
+   local sum_psnr = 0
    local sum_loss = 0
    local sum_eval = 0
    local count_loss = 0
@@ -55,6 +55,7 @@ local function minibatch_adam(model, criterion, eval_metric,
 	 else
 	    se = eval_metric:forward(output, targets)
 	 end
+	 sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6)))
 	 sum_eval = sum_eval + se
 	 sum_loss = sum_loss + f
 	 count_loss = count_loss + 1
@@ -69,10 +70,9 @@ local function minibatch_adam(model, criterion, eval_metric,
 	 collectgarbage()
 	 xlua.progress(t, train_x:size(1))
       end
-
    end
    xlua.progress(train_x:size(1), train_x:size(1))
-   return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}, instance_loss
+   return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = sum_psnr / count_loss}, instance_loss
 end
 
 return minibatch_adam

+ 1 - 0
lib/settings.lua

@@ -76,6 +76,7 @@ cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
 cmd:option("-gpu", 1, 'Device ID')
 cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
+cmd:option("-update_criterion", "mse", 'mse|loss')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then

+ 14 - 5
train.lua

@@ -262,6 +262,7 @@ local function make_validation_set(x, n, patches)
    return data
 end
 local function validate(model, criterion, eval_metric, data, batch_size)
+   local psnr = 0
    local loss = 0
    local mse = 0
    local loss_count = 0
@@ -286,8 +287,10 @@ local function validate(model, criterion, eval_metric, data, batch_size)
       inputs:copy(inputs_tmp)
       targets:copy(targets_tmp)
       local z = model:forward(inputs)
+      local batch_mse = eval_metric:forward(z, targets)
       loss = loss + criterion:forward(z, targets)
-      mse = mse + eval_metric:forward(z, targets)
+      mse = mse + batch_mse
+      psnr = psnr + (10 * math.log10(1 / batch_mse))
       loss_count = loss_count + 1
       if loss_count % 10 == 0 then
 	 xlua.progress(t, #data)
@@ -295,7 +298,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
       end
    end
    xlua.progress(#data, #data)
-   return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
+   return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = psnr / loss_count}
 end
 
 local function create_criterion(model)
@@ -540,9 +543,15 @@ local function train()
 	 if settings.plot then
 	    plot(hist_train, hist_valid)
 	 end
-	 if score.loss < best_score then
+	 local score_for_update
+	 if settings.update_criterion == "mse" then
+	    score_for_update = score.MSE
+	 else
+	    score_for_update = score.loss
+	 end
+	 if score_for_update < best_score then
 	    local test_image = image_loader.load_float(settings.test) -- reload
-	    best_score = score.loss
+	    best_score = score_for_update
 	    print("* model has updated")
 	    if settings.save_history then
 	       torch.save(settings.model_file_best, model:clearState(), "ascii")
@@ -591,7 +600,7 @@ local function train()
 	       end
 	    end
 	 end
-	 print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE)
+	 print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", best: " .. best_score)
 	 collectgarbage()
       end
    end