|
@@ -5,7 +5,7 @@ require 'xlua'
|
|
|
require 'pl'
|
|
|
|
|
|
local settings = require './lib/settings'
|
|
|
-local minibatch_sgd = require './lib/minibatch_sgd'
|
|
|
+local minibatch_adam = require './lib/minibatch_adam'
|
|
|
local iproc = require './lib/iproc'
|
|
|
local create_model = require './lib/srcnn'
|
|
|
local reconstract, reconstract_ch = require './lib/reconstract'
|
|
@@ -77,10 +77,6 @@ local function train()
|
|
|
learningRate = settings.learning_rate,
|
|
|
xBatchSize = settings.batch_size,
|
|
|
}
|
|
|
- local denoise_model = nil
|
|
|
- if settings.method == "scale" and path.exists(settings.denoise_model_file) then
|
|
|
- denoise_model = torch.load(settings.denoise_model_file)
|
|
|
- end
|
|
|
local transformer = function(x, is_validation)
|
|
|
if is_validation == nil then is_validation = false end
|
|
|
if settings.method == "scale" then
|
|
@@ -109,11 +105,11 @@ local function train()
|
|
|
for epoch = 1, settings.epoch do
|
|
|
model:training()
|
|
|
print("# " .. epoch)
|
|
|
- print(minibatch_sgd(model, criterion, train_x, adam_config,
|
|
|
- transformer,
|
|
|
- {1, settings.crop_size, settings.crop_size},
|
|
|
- {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
|
|
|
- ))
|
|
|
+ print(minibatch_adam(model, criterion, train_x, adam_config,
|
|
|
+ transformer,
|
|
|
+ {1, settings.crop_size, settings.crop_size},
|
|
|
+ {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
|
|
|
+ ))
|
|
|
if epoch % 1 == 0 then
|
|
|
collectgarbage()
|
|
|
model:evaluate()
|