ソースを参照

Fix validation metric

nagadomi 9 年 前
コミット
ea780f1871
1 ファイル変更31 行追加14 行削除
  1. 31 14
      train.lua

+ 31 - 14
train.lua

@@ -41,31 +41,49 @@ local function make_validation_set(x, transformer, n, patches)
    for i = 1, #x do
       for k = 1, math.max(n / patches, 1) do
 	 local xy = transformer(x[i], true, patches)
-	 local tx = torch.Tensor(patches, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
-	 local ty = torch.Tensor(patches, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
 	 for j = 1, #xy do
-	    tx[j]:copy(xy[j][1])
-	    ty[j]:copy(xy[j][2])
+	    table.insert(data, {x = xy[j][1], y = xy[j][2]})
 	 end
-	 table.insert(data, {x = tx, y = ty})
       end
       xlua.progress(i, #x)
       collectgarbage()
    end
    return data
 end
-local function validate(model, criterion, data)
+local function validate(model, criterion, data, batch_size)
    local loss = 0
-   for i = 1, #data do
-      local z = model:forward(data[i].x:cuda())
-      loss = loss + criterion:forward(z, data[i].y:cuda())
-      if i % 100 == 0 then
-	 xlua.progress(i, #data)
+   local loss_count = 0
+   local inputs_tmp = torch.Tensor(batch_size,
+				   data[1].x:size(1), 
+				   data[1].x:size(2),
+				   data[1].x:size(3)):zero()
+   local targets_tmp = torch.Tensor(batch_size,
+				    data[1].y:size(1),
+				    data[1].y:size(2),
+				    data[1].y:size(3)):zero()
+   local inputs = inputs_tmp:clone():cuda()
+   local targets = targets_tmp:clone():cuda()
+   
+   for t = 1, #data, batch_size do
+      if t + batch_size -1 > #data then
+	 break
+      end
+      for i = 1, batch_size do
+         inputs_tmp[i]:copy(data[t + i - 1].x)
+	 targets_tmp[i]:copy(data[t + i - 1].y)
+      end
+      inputs:copy(inputs_tmp)
+      targets:copy(targets_tmp)
+      local z = model:forward(inputs)
+      loss = loss + criterion:forward(z, targets)
+      loss_count = loss_count + 1
+      if t % 10 == 0 then
+	 xlua.progress(t, #data)
 	 collectgarbage()
       end
    end
    xlua.progress(#data, #data)
-   return loss / #data
+   return loss / loss_count
 end
 
 local function create_criterion(model)
@@ -214,8 +232,7 @@ local function train()
 	 print(train_score)
 	 model:evaluate()
 	 print("# validation")
-	 local score = validate(model, eval_metric, valid_xy)
-
+	 local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
 	 table.insert(hist_train, train_score.PSNR)
 	 table.insert(hist_valid, score)
 	 if settings.plot then