nagadomi 9 лет назад
Родитель
Сommit
81df729a8a
2 измененных файлов с 5 добавлено и 10 удалено
  1. 0 1
      lib/settings.lua
  2. 5 9
      train.lua

+ 0 - 1
lib/settings.lua

@@ -57,7 +57,6 @@ cmd:option("-resize_blur_max", 1.05, 'max blur parameter for ResizeImage')
 cmd:option("-oracle_rate", 0.1, '')
 cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
-cmd:option("-loss", "y", 'loss (rgb|y)')
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
 

+ 5 - 9
train.lua

@@ -101,18 +101,14 @@ local function validate(model, criterion, eval_metric, data, batch_size)
    return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
 end
 
-local function create_criterion(model, loss)
+local function create_criterion(model)
    if reconstruct.is_rgb(model) then
       local offset = reconstruct.offset_size(model)
       local output_w = settings.crop_size - offset * 2
       local weight = torch.Tensor(3, output_w * output_w)
-      if loss == "y" then
-	 weight[1]:fill(0.29891 * 3) -- R
-	 weight[2]:fill(0.58661 * 3) -- G
-	 weight[3]:fill(0.11448 * 3) -- B
-      else
-	 weight:fill(1)
-      end
+      weight[1]:fill(0.29891 * 3) -- R
+      weight[2]:fill(0.58661 * 3) -- G
+      weight[3]:fill(0.11448 * 3) -- B
       return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
    else
       local offset = reconstruct.offset_size(model)
@@ -309,7 +305,7 @@ local function train()
    local pairwise_func = function(x, is_validation, n)
       return transformer(model, x, is_validation, n, offset)
    end
-   local criterion = create_criterion(model, settings.loss)
+   local criterion = create_criterion(model)
    local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
    local x = remove_small_image(torch.load(settings.images))
    local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))