浏览代码

Use Huber loss instead of MSE

nagadomi 9 年之前
父节点
当前提交
290a5f960b
共有 1 个文件被更改,包括 5 次插入1 次删除
  1. 5 1
      train.lua

+ 5 - 1
train.lua

@@ -78,7 +78,11 @@ local function create_criterion(model)
       weight[3]:fill(0.11448 * 3) -- B
       return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
    else
-      return nn.MSECriterion():cuda()
+      local offset = reconstruct.offset_size(model)
+      local output_w = settings.crop_size - offset * 2
+      local weight = torch.Tensor(1, output_w * output_w)
+      weight[1]:fill(1.0)
+      return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
    end
 end
 local function transformer(x, is_validation, n, offset)