|
@@ -157,8 +157,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
|
|
|
end
|
|
|
end
|
|
|
end
|
|
|
-
|
|
|
+local function plot(train, valid)
|
|
|
+ gnuplot.plot({
|
|
|
+ {'training', torch.Tensor(train), '-'},
|
|
|
+ {'validation', torch.Tensor(valid), '-'}})
|
|
|
+end
|
|
|
local function train()
|
|
|
+ local hist_train = {}
|
|
|
+ local hist_valid = {}
|
|
|
local LR_MIN = 1.0e-5
|
|
|
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
|
|
local offset = reconstruct.offset_size(model)
|
|
@@ -201,10 +207,17 @@ local function train()
|
|
|
print("# " .. epoch)
|
|
|
resampling(x, y, train_x, pairwise_func)
|
|
|
for i = 1, settings.inner_epoch do
|
|
|
- print(minibatch_adam(model, criterion, eval_metric, x, y, adam_config))
|
|
|
+ local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
|
|
|
+ print(train_score)
|
|
|
model:evaluate()
|
|
|
print("# validation")
|
|
|
local score = validate(model, eval_metric, valid_xy)
|
|
|
+
|
|
|
+ table.insert(hist_train, train_score.PSNR)
|
|
|
+ table.insert(hist_valid, score)
|
|
|
+ if settings.plot then
|
|
|
+ plot(hist_train, hist_valid)
|
|
|
+ end
|
|
|
if score > best_score then
|
|
|
local test_image = image_loader.load_float(settings.test) -- reload
|
|
|
lrd_count = 0
|