
Minimize the weighted Huber loss instead of the weighted mean squared error

Huber loss is less sensitive to outliers (i.e., noise) in the data than the squared error loss.
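
For reference, the Huber loss is quadratic for small residuals and linear for large ones, so a few extreme pixel errors cannot dominate the gradient the way they do under MSE. A minimal Lua/Torch sketch of the per-element loss (the function name and the explicit delta argument are illustrative, not part of this commit):

require 'torch'

-- Minimal sketch of the element-wise Huber loss, assuming Torch7 tensors.
-- For residual r = z - y and threshold delta:
--   0.5 * r^2                      if |r| <= delta
--   delta * (|r| - 0.5 * delta)    otherwise
-- Quadratic near zero (like MSE) but linear for large residuals.
local function huber_loss(z, y, delta)
   local r = torch.add(z, -1, y)                -- residual z - y
   local abs_r = torch.abs(r)
   local quad = torch.cmul(r, r):mul(0.5)       -- 0.5 * r^2
   local lin = torch.add(abs_r, -0.5 * delta):mul(delta)
   local q = abs_r:le(delta):typeAs(r)          -- quadratic-region mask
   local l = abs_r:gt(delta):typeAs(r)          -- linear-region mask
   return torch.cmul(quad, q):add(torch.cmul(lin, l)):mean()
end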
nagadomi 9 years ago
parent commit 490eb33a6b
2 changed files with 19 additions and 17 deletions
  1. lib/minibatch_adam.lua (+1 -3)
  2. train.lua (+18 -14)

lib/minibatch_adam.lua (+1 -3)

@@ -21,7 +21,6 @@ local function minibatch_adam(model, criterion,
 			       input_size[1], input_size[2], input_size[3])
    local targets_tmp = torch.Tensor(batch_size,
 				    target_size[1] * target_size[2] * target_size[3])
-   
    for t = 1, #train_x do
       xlua.progress(t, #train_x)
       local xy = transformer(train_x[shuffle[t]], false, batch_size)
@@ -31,7 +30,6 @@ local function minibatch_adam(model, criterion,
       end
       inputs:copy(inputs_tmp)
       targets:copy(targets_tmp)
-      
       local feval = function(x)
 	 if x ~= parameters then
 	    parameters:copy(x)
@@ -53,7 +51,7 @@ local function minibatch_adam(model, criterion,
    end
    xlua.progress(#train_x, #train_x)
    
-   return { mse = sum_loss / count_loss}
+   return { loss = sum_loss / count_loss}
 end
 
 return minibatch_adam
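
Because the criterion is no longer necessarily MSE, the table returned above now uses the generic key loss rather than mse. A hypothetical call site (the module path and the trailing arguments are assumptions inferred from the variables visible in the hunk, not code from this commit):

local minibatch_adam = require 'minibatch_adam' -- module path is an assumption
-- train_x, adam_config, transformer, input_size and target_size are assumed
-- to be set up by the caller as in the hunk above
local score = minibatch_adam(model, criterion, train_x, adam_config,
                             transformer, input_size, target_size)
print(string.format("training loss: %f", score.loss)) -- previously score.mse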

train.lua (+18 -14)

@@ -35,18 +35,19 @@ local function split_data(x, test_size)
    end
    return train_x, valid_x
 end
-local function make_validation_set(x, transformer, n)
+local function make_validation_set(x, transformer, n, batch_size)
    n = n or 4
    local data = {}
    for i = 1, #x do
-      for k = 1, math.max(n / 8, 1) do
-	 local xy = transformer(x[i], true, 8)
+      for k = 1, math.max(n / batch_size, 1) do
+	 local xy = transformer(x[i], true, batch_size)
+	 local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
+	 local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
 	 for j = 1, #xy do
-	    local x = xy[j][1]
-	    local y = xy[j][2]
-	    table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
-				y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
+	    tx[j]:copy(xy[j][1])
+	    ty[j]:copy(xy[j][2])
 	 end
+	 table.insert(data, {x = tx, y = ty})
       end
       xlua.progress(i, #x)
       collectgarbage()
@@ -58,11 +59,12 @@ local function validate(model, criterion, data)
    for i = 1, #data do
       local z = model:forward(data[i].x:cuda())
       loss = loss + criterion:forward(z, data[i].y:cuda())
-      xlua.progress(i, #data)
-      if i % 10 == 0 then
+      if i % 100 == 0 then
+	 xlua.progress(i, #data)
 	 collectgarbage()
       end
    end
+   xlua.progress(#data, #data)
    return loss / #data
 end
 
@@ -71,10 +73,10 @@ local function create_criterion(model)
       local offset = reconstruct.offset_size(model)
       local output_w = settings.crop_size - offset * 2
       local weight = torch.Tensor(3, output_w * output_w)
-      weight[1]:fill(0.299 * 3) -- R
-      weight[2]:fill(0.587 * 3) -- G
-      weight[3]:fill(0.114 * 3) -- B
-      return w2nn.WeightedMSECriterion(weight):cuda()
+      weight[1]:fill(0.29891 * 3) -- R
+      weight[2]:fill(0.58661 * 3) -- G
+      weight[3]:fill(0.11448 * 3) -- B
+      return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
    else
       return nn.MSECriterion():cuda()
    end
@@ -151,7 +153,9 @@ local function train()
    end
    local best_score = 100000.0
    print("# make validation-set")
-   local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops)
+   local valid_xy = make_validation_set(valid_x, pairwise_func,
+					settings.validation_crops,
+					settings.batch_size)
    valid_x = nil
    
    collectgarbage()
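
Two things change in create_criterion: the per-channel weights move from the rounded Rec. 601 luma coefficients (0.299/0.587/0.114) to the more precise 0.29891/0.58661/0.11448, and the criterion becomes w2nn.WeightedHuberCriterion with threshold 0.1 instead of w2nn.WeightedMSECriterion. That criterion is project-specific; the sketch below only illustrates the idea of a per-element weighted Huber loss as an nn.Criterion. The class name and every implementation detail here are assumptions, not the actual w2nn code.

require 'torch'
require 'nn'

-- Illustrative sketch, not the real w2nn.WeightedHuberCriterion.
-- `weight` is a per-element weight tensor (above: 3 x output_w^2, one luma
-- coefficient per channel); `gamma` is the Huber threshold (0.1 above).
local WeightedHuber, parent = torch.class('WeightedHuberSketch', 'nn.Criterion')

function WeightedHuber:__init(weight, gamma)
   parent.__init(self)
   self.weight = weight:reshape(1, weight:nElement())
   self.gamma = gamma or 0.1
end

function WeightedHuber:updateOutput(input, target)
   -- input/target: batch x (3 * output_w^2), matching the flattened targets
   local r = torch.add(input, -1, target)
   local abs_r = torch.abs(r)
   local quad = torch.cmul(r, r):mul(0.5)                      -- 0.5 * r^2
   local lin  = torch.add(abs_r, -0.5 * self.gamma):mul(self.gamma)
   local loss = torch.cmul(quad, abs_r:le(self.gamma):typeAs(r))
   loss:add(torch.cmul(lin, abs_r:gt(self.gamma):typeAs(r)))
   loss:cmul(self.weight:expandAs(loss))                       -- luma weighting
   self.output = loss:mean()
   return self.output
end

function WeightedHuber:updateGradInput(input, target)
   -- Huber gradient: r in the quadratic region, gamma * sign(r) outside,
   -- which clamp(r, -gamma, gamma) computes in one step.
   local r = torch.add(input, -1, target)
   self.gradInput = torch.clamp(r, -self.gamma, self.gamma)
   self.gradInput:cmul(self.weight:expandAs(self.gradInput))
   self.gradInput:div(input:nElement())                        -- from mean()
   return self.gradInput
end

With weight filled as in create_criterion, each pixel's error is scaled by three times its channel's luma coefficient, so errors in the green channel, to which human vision is most sensitive, are penalized most.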