Browse Source

Use correct criterion

nagadomi 9 years ago
parent
commit
b96bc5d453
1 changed files with 11 additions and 9 deletions
  1. 11 9
      train.lua

+ 11 - 9
train.lua

@@ -58,8 +58,9 @@ local function make_validation_set(x, transformer, n, patches)
    data = new_data
    return data
 end
-local function validate(model, criterion, data, batch_size)
+local function validate(model, criterion, eval_metric, data, batch_size)
    local loss = 0
+   local mse = 0
    local loss_count = 0
    local inputs_tmp = torch.Tensor(batch_size,
 				   data[1].x:size(1), 
@@ -83,6 +84,7 @@ local function validate(model, criterion, data, batch_size)
       targets:copy(targets_tmp)
       local z = model:forward(inputs)
       loss = loss + criterion:forward(z, targets)
+      mse = mse + eval_metric:forward(z, targets)
       loss_count = loss_count + 1
       if loss_count % 10 == 0 then
 	 xlua.progress(t, #data)
@@ -90,7 +92,7 @@ local function validate(model, criterion, data, batch_size)
       end
    end
    xlua.progress(#data, #data)
-   return loss / loss_count
+   return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
 end
 
 local function create_criterion(model)
@@ -247,7 +249,7 @@ local function train()
       return transformer(model, x, is_validation, n, offset)
    end
    local criterion = create_criterion(model)
-   local eval_metric = nn.MSECriterion():cuda()
+   local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
    local x = remove_small_image(torch.load(settings.images))
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
    local adam_config = {
@@ -312,16 +314,16 @@ local function train()
 	 print(train_score)
 	 model:evaluate()
 	 print("# validation")
-	 local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
-	 table.insert(hist_train, train_score.MSE)
-	 table.insert(hist_valid, score)
+	 local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
+	 table.insert(hist_train, train_score.loss)
+	 table.insert(hist_valid, score.loss)
 	 if settings.plot then
 	    plot(hist_train, hist_valid)
 	 end
-	 if score < best_score then
+	 if score.loss < best_score then
 	    local test_image = image_loader.load_float(settings.test) -- reload
 	    lrd_count = 0
-	    best_score = score
+	    best_score = score.loss
 	    print("* update best model")
 	    if settings.save_history then
 	       torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
@@ -356,7 +358,7 @@ local function train()
 	       lrd_count = 0
 	    end
 	 end
-	 print("PSNR: " .. 10 * math.log10(1 / score) .. ", MSE: " .. score .. ", Best MSE: " .. best_score)
+	 print("PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score)
 	 collectgarbage()
       end
    end