瀏覽代碼

Fix missing file

nagadomi 9 年之前
父節點
當前提交
70a2849e39
共有 1 個文件被更改,包括 9 次插入2 次删除
  1. 9 2
      lib/minibatch_adam.lua

+ 9 - 2
lib/minibatch_adam.lua

@@ -19,6 +19,7 @@ local function minibatch_adam(model, criterion, eval_metric,
 				    train_y:size(2)):zero()
    local inputs = inputs_tmp:clone():cuda()
    local targets = targets_tmp:clone():cuda()
+   local instance_loss = torch.Tensor(train_x:size(1)):zero()
 
    print("## update")
    for t = 1, train_x:size(1), batch_size do
@@ -38,7 +39,13 @@ local function minibatch_adam(model, criterion, eval_metric,
 	 gradParameters:zero()
 	 local output = model:forward(inputs)
 	 local f = criterion:forward(output, targets)
-	 sum_eval = sum_eval + eval_metric:forward(output, targets)
+	 local se = 0
+	 for i = 1, batch_size do
+	    local el = eval_metric:forward(output[i], targets[i])
+	    se = se + el
+	    instance_loss[shuffle[t + i - 1]] = el
+	 end
+	 sum_eval = sum_eval + (se / batch_size)
 	 sum_loss = sum_loss + f
 	 count_loss = count_loss + 1
 	 model:backward(inputs, criterion:backward(output, targets))
@@ -52,7 +59,7 @@ local function minibatch_adam(model, criterion, eval_metric,
       end
    end
    xlua.progress(train_x:size(1), train_x:size(1))
-   return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}
+   return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}, instance_loss
 end
 
 return minibatch_adam