Minimize the weighted Huber loss instead of the weighted mean squared error

Huber loss is less sensitive to outliers (i.e., noise) in the data than the squared error loss.
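
For reference, here is a minimal sketch (not the repository's implementation) of the pointwise Huber loss, assuming the same delta = 0.1 that this commit passes to w2nn.WeightedHuberCriterion in train.lua. The loss is quadratic for small residuals, like the squared error, but grows only linearly beyond delta, so a few noisy pixels contribute a bounded gradient instead of dominating the update.

-- Minimal sketch, not the repo's w2nn.WeightedHuberCriterion:
-- pointwise Huber loss on a residual r. Quadratic for |r| <= delta,
-- linear outside, so an outlier's gradient magnitude is capped at delta.
local function huber(r, delta)
   delta = delta or 0.1 -- the delta this commit uses
   local a = math.abs(r)
   if a <= delta then
      return 0.5 * r * r
   else
      return delta * (a - 0.5 * delta)
   end
end

print(huber(0.05), huber(1.0)) -- 0.00125  0.095 (squared error would give 0.5 for r = 1)
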
nagadomi, 9 years ago
parent commit 490eb33a6b
2 changed files, 19 additions and 17 deletions
  1. lib/minibatch_adam.lua (+1 -3)
  2. train.lua (+18 -14)

lib/minibatch_adam.lua (+1 -3)

@@ -21,7 +21,6 @@ local function minibatch_adam(model, criterion,
 			       input_size[1], input_size[2], input_size[3])
    local targets_tmp = torch.Tensor(batch_size,
 				    target_size[1] * target_size[2] * target_size[3])
-   
    for t = 1, #train_x do
       xlua.progress(t, #train_x)
       local xy = transformer(train_x[shuffle[t]], false, batch_size)
@@ -31,7 +30,6 @@ local function minibatch_adam(model, criterion,
       end
       inputs:copy(inputs_tmp)
       targets:copy(targets_tmp)
-      
       local feval = function(x)
 	 if x ~= parameters then
 	    parameters:copy(x)
@@ -53,7 +51,7 @@ local function minibatch_adam(model, criterion,
    end
    xlua.progress(#train_x, #train_x)
    
-   return { mse = sum_loss / count_loss}
+   return { loss = sum_loss / count_loss}
 end
 
 return minibatch_adam
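
Because the criterion may now be Huber rather than squared error, minibatch_adam returns its per-epoch average under the generic key loss instead of mse. A hypothetical caller (the real call site is in train.lua, outside the hunks shown here, and the argument list below is a guess) only needs the matching key rename:

-- Hypothetical caller sketch; the argument list is assumed, and
-- only the returned key is what this hunk actually changes.
local score = minibatch_adam(model, criterion, train_x, config,
                             transformer, input_size, target_size)
print("training loss:", score.loss) -- previously score.mse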

train.lua (+18 -14)

@@ -35,18 +35,19 @@ local function split_data(x, test_size)
    end
    return train_x, valid_x
 end
-local function make_validation_set(x, transformer, n)
+local function make_validation_set(x, transformer, n, batch_size)
    n = n or 4
    local data = {}
    for i = 1, #x do
-      for k = 1, math.max(n / 8, 1) do
-	 local xy = transformer(x[i], true, 8)
+      for k = 1, math.max(n / batch_size, 1) do
+	 local xy = transformer(x[i], true, batch_size)
+	 local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
+	 local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
 	 for j = 1, #xy do
-	    local x = xy[j][1]
-	    local y = xy[j][2]
-	    table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
-				y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
+	    tx[j]:copy(xy[j][1])
+	    ty[j]:copy(xy[j][2])
 	 end
+	 table.insert(data, {x = tx, y = ty})
       end
       xlua.progress(i, #x)
       collectgarbage()
@@ -58,11 +59,12 @@ local function validate(model, criterion, data)
    for i = 1, #data do
       local z = model:forward(data[i].x:cuda())
       loss = loss + criterion:forward(z, data[i].y:cuda())
-      xlua.progress(i, #data)
-      if i % 10 == 0 then
+      if i % 100 == 0 then
+	 xlua.progress(i, #data)
 	 collectgarbage()
       end
    end
+   xlua.progress(#data, #data)
    return loss / #data
 end
 
@@ -71,10 +73,10 @@ local function create_criterion(model)
       local offset = reconstruct.offset_size(model)
       local output_w = settings.crop_size - offset * 2
       local weight = torch.Tensor(3, output_w * output_w)
-      weight[1]:fill(0.299 * 3) -- R
-      weight[2]:fill(0.587 * 3) -- G
-      weight[3]:fill(0.114 * 3) -- B
-      return w2nn.WeightedMSECriterion(weight):cuda()
+      weight[1]:fill(0.29891 * 3) -- R
+      weight[2]:fill(0.58661 * 3) -- G
+      weight[3]:fill(0.11448 * 3) -- B
+      return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
    else
       return nn.MSECriterion():cuda()
    end
@@ -151,7 +153,9 @@ local function train()
    end
    local best_score = 100000.0
    print("# make validation-set")
-   local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops)
+   local valid_xy = make_validation_set(valid_x, pairwise_func,
+					settings.validation_crops,
+					settings.batch_size)
    valid_x = nil
    
    collectgarbage()
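
For context on the create_criterion hunk above: the updated RGB weights are the ITU-R BT.601 luma coefficients at higher precision (Y = 0.29891 R + 0.58661 G + 0.11448 B), so per-pixel errors are weighted by their contribution to perceived brightness, and the second argument to w2nn.WeightedHuberCriterion is presumably the Huber delta of 0.1. A quick sketch of why the weights are scaled by 3:

-- Sketch: the BT.601 luma coefficients sum to 1, so multiplying by 3
-- makes the three per-channel weights average 1, which keeps the
-- weighted loss on roughly the same scale as an unweighted criterion.
require 'torch'
local w = torch.Tensor({0.29891, 0.58661, 0.11448})
print(w:sum())        -- 1.0
print((w * 3):mean()) -- 1.0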