|
@@ -101,18 +101,14 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|
|
return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
|
|
|
end
|
|
|
|
|
|
-local function create_criterion(model, loss)
|
|
|
+local function create_criterion(model)
|
|
|
if reconstruct.is_rgb(model) then
|
|
|
local offset = reconstruct.offset_size(model)
|
|
|
local output_w = settings.crop_size - offset * 2
|
|
|
local weight = torch.Tensor(3, output_w * output_w)
|
|
|
- if loss == "y" then
|
|
|
- weight[1]:fill(0.29891 * 3) -- R
|
|
|
- weight[2]:fill(0.58661 * 3) -- G
|
|
|
- weight[3]:fill(0.11448 * 3) -- B
|
|
|
- else
|
|
|
- weight:fill(1)
|
|
|
- end
|
|
|
+ weight[1]:fill(0.29891 * 3) -- R
|
|
|
+ weight[2]:fill(0.58661 * 3) -- G
|
|
|
+ weight[3]:fill(0.11448 * 3) -- B
|
|
|
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
|
|
else
|
|
|
local offset = reconstruct.offset_size(model)
|
|
@@ -309,7 +305,7 @@ local function train()
|
|
|
local pairwise_func = function(x, is_validation, n)
|
|
|
return transformer(model, x, is_validation, n, offset)
|
|
|
end
|
|
|
- local criterion = create_criterion(model, settings.loss)
|
|
|
+ local criterion = create_criterion(model)
|
|
|
local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
|
|
|
local x = remove_small_image(torch.load(settings.images))
|
|
|
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|