Преглед изворни кода

Add BCE(binary cross entropy) loss for segmentation

Sigmoid() output is required.
nagadomi пре 8 година
родитељ
комит
cdafbf00ae
2 измењених фајлова са 5 додато и 1 уклоњено
  1. 1 1
      lib/settings.lua
  2. 4 0
      train.lua

+ 1 - 1
lib/settings.lua

@@ -75,7 +75,7 @@ cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate *
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
 cmd:option("-gpu", 1, 'Device ID')
-cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
+cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
 cmd:option("-update_criterion", "mse", 'mse|loss')
 
 local function to_bool(settings, name)

+ 4 - 0
train.lua

@@ -322,6 +322,10 @@ local function create_criterion(model)
       return w2nn.L1Criterion():cuda()
    elseif settings.loss == "mse" then
       return w2nn.ClippedMSECriterion(0, 1.0):cuda()
+   elseif settings.loss == "bce" then
+      local bce = nn.BCECriterion()
+      bce.sizeAverage = true
+      return bce:cuda()
    else
       error("unsupported loss .." .. settings.loss)
    end