|
@@ -78,7 +78,11 @@ local function create_criterion(model)
|
|
weight[3]:fill(0.11448 * 3) -- B
|
|
weight[3]:fill(0.11448 * 3) -- B
|
|
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
|
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
|
else
|
|
else
|
|
- return nn.MSECriterion():cuda()
|
|
|
|
|
|
+ local offset = reconstruct.offset_size(model)
|
|
|
|
+ local output_w = settings.crop_size - offset * 2
|
|
|
|
+ local weight = torch.Tensor(1, output_w * output_w)
|
|
|
|
+ weight[1]:fill(1.0)
|
|
|
|
+ return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
|
end
|
|
end
|
|
end
|
|
end
|
|
local function transformer(x, is_validation, n, offset)
|
|
local function transformer(x, is_validation, n, offset)
|
|
@@ -91,8 +95,8 @@ local function transformer(x, is_validation, n, offset)
|
|
local active_cropping_rate = nil
|
|
local active_cropping_rate = nil
|
|
local active_cropping_tries = nil
|
|
local active_cropping_tries = nil
|
|
if is_validation then
|
|
if is_validation then
|
|
- active_cropping_rate = 0
|
|
|
|
- active_cropping_tries = 0
|
|
|
|
|
|
+ active_cropping_rate = settings.active_cropping_rate
|
|
|
|
+ active_cropping_tries = settings.active_cropping_tries
|
|
random_color_noise_rate = 0.0
|
|
random_color_noise_rate = 0.0
|
|
random_overlay_rate = 0.0
|
|
random_overlay_rate = 0.0
|
|
else
|
|
else
|
|
@@ -108,6 +112,7 @@ local function transformer(x, is_validation, n, offset)
|
|
settings.crop_size, offset,
|
|
settings.crop_size, offset,
|
|
n,
|
|
n,
|
|
{
|
|
{
|
|
|
|
+ downsampling_filters = settings.downsampling_filters,
|
|
random_half_rate = settings.random_half_rate,
|
|
random_half_rate = settings.random_half_rate,
|
|
random_color_noise_rate = random_color_noise_rate,
|
|
random_color_noise_rate = random_color_noise_rate,
|
|
random_overlay_rate = random_overlay_rate,
|
|
random_overlay_rate = random_overlay_rate,
|
|
@@ -153,8 +158,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
end
|
|
end
|
|
-
|
|
|
|
|
|
+local function plot(train, valid)
|
|
|
|
+ gnuplot.plot({
|
|
|
|
+ {'training', torch.Tensor(train), '-'},
|
|
|
|
+ {'validation', torch.Tensor(valid), '-'}})
|
|
|
|
+end
|
|
local function train()
|
|
local function train()
|
|
|
|
+ local hist_train = {}
|
|
|
|
+ local hist_valid = {}
|
|
local LR_MIN = 1.0e-5
|
|
local LR_MIN = 1.0e-5
|
|
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
|
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
|
local offset = reconstruct.offset_size(model)
|
|
local offset = reconstruct.offset_size(model)
|
|
@@ -162,6 +173,7 @@ local function train()
|
|
return transformer(x, is_validation, n, offset)
|
|
return transformer(x, is_validation, n, offset)
|
|
end
|
|
end
|
|
local criterion = create_criterion(model)
|
|
local criterion = create_criterion(model)
|
|
|
|
+ local eval_metric = w2nn.PSNRCriterion():cuda()
|
|
local x = torch.load(settings.images)
|
|
local x = torch.load(settings.images)
|
|
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
|
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
|
local adam_config = {
|
|
local adam_config = {
|
|
@@ -175,7 +187,7 @@ local function train()
|
|
elseif settings.color == "rgb" then
|
|
elseif settings.color == "rgb" then
|
|
ch = 3
|
|
ch = 3
|
|
end
|
|
end
|
|
- local best_score = 100000.0
|
|
|
|
|
|
+ local best_score = 0.0
|
|
print("# make validation-set")
|
|
print("# make validation-set")
|
|
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
|
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
|
settings.validation_crops,
|
|
settings.validation_crops,
|
|
@@ -196,19 +208,24 @@ local function train()
|
|
print("# " .. epoch)
|
|
print("# " .. epoch)
|
|
resampling(x, y, train_x, pairwise_func)
|
|
resampling(x, y, train_x, pairwise_func)
|
|
for i = 1, settings.inner_epoch do
|
|
for i = 1, settings.inner_epoch do
|
|
- print(minibatch_adam(model, criterion, x, y, adam_config))
|
|
|
|
|
|
+ local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
|
|
|
|
+ print(train_score)
|
|
model:evaluate()
|
|
model:evaluate()
|
|
print("# validation")
|
|
print("# validation")
|
|
- local score = validate(model, criterion, valid_xy)
|
|
|
|
- if score < best_score then
|
|
|
|
|
|
+ local score = validate(model, eval_metric, valid_xy)
|
|
|
|
+
|
|
|
|
+ table.insert(hist_train, train_score.PSNR)
|
|
|
|
+ table.insert(hist_valid, score)
|
|
|
|
+ if settings.plot then
|
|
|
|
+ plot(hist_train, hist_valid)
|
|
|
|
+ end
|
|
|
|
+ if score > best_score then
|
|
local test_image = image_loader.load_float(settings.test) -- reload
|
|
local test_image = image_loader.load_float(settings.test) -- reload
|
|
lrd_count = 0
|
|
lrd_count = 0
|
|
best_score = score
|
|
best_score = score
|
|
print("* update best model")
|
|
print("* update best model")
|
|
if settings.save_history then
|
|
if settings.save_history then
|
|
- local model_clone = model:clone()
|
|
|
|
- w2nn.cleanup_model(model_clone)
|
|
|
|
- torch.save(string.format(settings.model_file, epoch, i), model_clone)
|
|
|
|
|
|
+ torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
|
|
if settings.method == "noise" then
|
|
if settings.method == "noise" then
|
|
local log = path.join(settings.model_dir,
|
|
local log = path.join(settings.model_dir,
|
|
("noise%d_best.%d-%d.png"):format(settings.noise_level,
|
|
("noise%d_best.%d-%d.png"):format(settings.noise_level,
|
|
@@ -221,7 +238,7 @@ local function train()
|
|
save_test_scale(model, test_image, log)
|
|
save_test_scale(model, test_image, log)
|
|
end
|
|
end
|
|
else
|
|
else
|
|
- torch.save(settings.model_file, model)
|
|
|
|
|
|
+ torch.save(settings.model_file, model:clearState(), "ascii")
|
|
if settings.method == "noise" then
|
|
if settings.method == "noise" then
|
|
local log = path.join(settings.model_dir,
|
|
local log = path.join(settings.model_dir,
|
|
("noise%d_best.png"):format(settings.noise_level))
|
|
("noise%d_best.png"):format(settings.noise_level))
|