
Minimize the weighted Huber loss instead of the weighted mean squared error

Huber loss is less sensitive to outliers (i.e. noise) in data than the squared error loss.
nagadomi 9 years ago
parent
commit
490eb33a6b
2 files changed with 19 additions and 17 deletions
  1. lib/minibatch_adam.lua (+1 -3)
  2. train.lua (+18 -14)
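
For reference, here is a minimal sketch of the per-element Huber loss the commit message refers to, written against plain Torch tensors. The helper name huber_loss and the fixed threshold delta are illustrative only; the repository's actual criterion is w2nn.WeightedHuberCriterion (see the train.lua hunk below), which additionally applies per-channel weights.

-- Sketch only: per-element Huber loss with threshold `delta`.
-- For a residual r = z - y:
--   0.5 * r^2                     if |r| <= delta
--   delta * (|r| - 0.5 * delta)   otherwise
-- Large residuals (outliers/noise) grow linearly instead of quadratically,
-- which is why Huber is less outlier-sensitive than squared error.
require 'torch'
local function huber_loss(z, y, delta)
   local r = torch.abs(z - y)
   local quad = torch.clamp(r, 0, delta)      -- min(|r|, delta)
   return torch.cmul(quad, r - quad * 0.5):mean()
end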

lib/minibatch_adam.lua (+1 -3)

@@ -21,7 +21,6 @@ local function minibatch_adam(model, criterion,
 			       input_size[1], input_size[2], input_size[3])
    local targets_tmp = torch.Tensor(batch_size,
 				    target_size[1] * target_size[2] * target_size[3])
-   
    for t = 1, #train_x do
       xlua.progress(t, #train_x)
       local xy = transformer(train_x[shuffle[t]], false, batch_size)
@@ -31,7 +30,6 @@ local function minibatch_adam(model, criterion,
       end
       inputs:copy(inputs_tmp)
       targets:copy(targets_tmp)
-      
      local feval = function(x)
 	 if x ~= parameters then
 	    parameters:copy(x)
@@ -53,7 +51,7 @@ local function minibatch_adam(model, criterion,
    end
    xlua.progress(#train_x, #train_x)
   
-   return { mse = sum_loss / count_loss}
+   return { loss = sum_loss / count_loss}
 end
 
 return minibatch_adam
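
Since the table returned by minibatch_adam now uses the key loss instead of mse, any code reading the per-epoch statistic has to switch fields. A hypothetical caller-side sketch (the argument order is illustrative; the real training loop lives in train.lua):

-- Hypothetical usage sketch; see lib/minibatch_adam.lua for the actual signature.
local minibatch_adam = require 'minibatch_adam'   -- assumes lib/ is on package.path
local stats = minibatch_adam(model, criterion, train_x, config,
                             transformer, input_size, target_size)
print(string.format("training loss: %f", stats.loss))   -- this key was `mse` before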

train.lua (+18 -14)

@@ -35,18 +35,19 @@ local function split_data(x, test_size)
    end
    return train_x, valid_x
 end
-local function make_validation_set(x, transformer, n)
+local function make_validation_set(x, transformer, n, batch_size)
    n = n or 4
    local data = {}
    for i = 1, #x do
-      for k = 1, math.max(n / 8, 1) do
-	 local xy = transformer(x[i], true, 8)
+      for k = 1, math.max(n / batch_size, 1) do
+	 local xy = transformer(x[i], true, batch_size)
+	 local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
+	 local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
 	 for j = 1, #xy do
-	    local x = xy[j][1]
-	    local y = xy[j][2]
-	    table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
-				y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
+	    tx[j]:copy(xy[j][1])
+	    ty[j]:copy(xy[j][2])
 	 end
+	 table.insert(data, {x = tx, y = ty})
       end
       xlua.progress(i, #x)
       collectgarbage()
@@ -58,11 +59,12 @@ local function validate(model, criterion, data)
    for i = 1, #data do
       local z = model:forward(data[i].x:cuda())
       loss = loss + criterion:forward(z, data[i].y:cuda())
-      xlua.progress(i, #data)
-      if i % 10 == 0 then
+      if i % 100 == 0 then
+	 xlua.progress(i, #data)
 	 collectgarbage()
       end
    end
+   xlua.progress(#data, #data)
    return loss / #data
 end
 
@@ -71,10 +73,10 @@ local function create_criterion(model)
       local offset = reconstruct.offset_size(model)
       local output_w = settings.crop_size - offset * 2
       local weight = torch.Tensor(3, output_w * output_w)
-      weight[1]:fill(0.299 * 3) -- R
-      weight[2]:fill(0.587 * 3) -- G
-      weight[3]:fill(0.114 * 3) -- B
-      return w2nn.WeightedMSECriterion(weight):cuda()
+      weight[1]:fill(0.29891 * 3) -- R
+      weight[2]:fill(0.58661 * 3) -- G
+      weight[3]:fill(0.11448 * 3) -- B
+      return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
    else
       return nn.MSECriterion():cuda()
    end
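
The weight tensor built above scales each channel's contribution by a higher-precision variant of the familiar 0.299/0.587/0.114 luma coefficients (normalized so the three weights sum to 3), so errors in G count most and errors in B least. A rough sketch of how such a weighted Huber forward pass could be computed, assuming flattened batch_size x (3 * output_w * output_w) outputs as the training code produces; the actual w2nn.WeightedHuberCriterion lives elsewhere in the repository and may differ in detail:

-- Illustrative only: combine the per-channel weight tensor with the
-- per-element Huber loss. `weight` is 3 x (output_w * output_w) as above,
-- `z` and `y` are batch_size x (3 * output_w * output_w).
local function weighted_huber_forward(z, y, weight, delta)
   local r = torch.abs(z - y)
   local quad = torch.clamp(r, 0, delta)                 -- min(|r|, delta)
   local huber = torch.cmul(quad, r - quad * 0.5)        -- per-element Huber
   local w = weight:view(1, weight:nElement()):expandAs(huber)
   return torch.cmul(huber, w):mean()                    -- weighted average loss
end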
@@ -151,7 +153,9 @@ local function train()
    end
    local best_score = 100000.0
    print("# make validation-set")
-   local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops)
+   local valid_xy = make_validation_set(valid_x, pairwise_func,
+					settings.validation_crops,
+					settings.batch_size)
    valid_x = nil
    
    collectgarbage()