|
@@ -41,31 +41,49 @@ local function make_validation_set(x, transformer, n, patches)
|
|
|
for i = 1, #x do
|
|
|
for k = 1, math.max(n / patches, 1) do
|
|
|
local xy = transformer(x[i], true, patches)
|
|
|
- local tx = torch.Tensor(patches, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
|
|
|
- local ty = torch.Tensor(patches, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
|
|
|
for j = 1, #xy do
|
|
|
- tx[j]:copy(xy[j][1])
|
|
|
- ty[j]:copy(xy[j][2])
|
|
|
+ table.insert(data, {x = xy[j][1], y = xy[j][2]})
|
|
|
end
|
|
|
- table.insert(data, {x = tx, y = ty})
|
|
|
end
|
|
|
xlua.progress(i, #x)
|
|
|
collectgarbage()
|
|
|
end
|
|
|
return data
|
|
|
end
|
|
|
-local function validate(model, criterion, data)
|
|
|
+local function validate(model, criterion, data, batch_size)
|
|
|
local loss = 0
|
|
|
- for i = 1, #data do
|
|
|
- local z = model:forward(data[i].x:cuda())
|
|
|
- loss = loss + criterion:forward(z, data[i].y:cuda())
|
|
|
- if i % 100 == 0 then
|
|
|
- xlua.progress(i, #data)
|
|
|
+ local loss_count = 0
|
|
|
+ local inputs_tmp = torch.Tensor(batch_size,
|
|
|
+ data[1].x:size(1),
|
|
|
+ data[1].x:size(2),
|
|
|
+ data[1].x:size(3)):zero()
|
|
|
+ local targets_tmp = torch.Tensor(batch_size,
|
|
|
+ data[1].y:size(1),
|
|
|
+ data[1].y:size(2),
|
|
|
+ data[1].y:size(3)):zero()
|
|
|
+ local inputs = inputs_tmp:clone():cuda()
|
|
|
+ local targets = targets_tmp:clone():cuda()
|
|
|
+
|
|
|
+ for t = 1, #data, batch_size do
|
|
|
+ if t + batch_size -1 > #data then
|
|
|
+ break
|
|
|
+ end
|
|
|
+ for i = 1, batch_size do
|
|
|
+ inputs_tmp[i]:copy(data[t + i - 1].x)
|
|
|
+ targets_tmp[i]:copy(data[t + i - 1].y)
|
|
|
+ end
|
|
|
+ inputs:copy(inputs_tmp)
|
|
|
+ targets:copy(targets_tmp)
|
|
|
+ local z = model:forward(inputs)
|
|
|
+ loss = loss + criterion:forward(z, targets)
|
|
|
+ loss_count = loss_count + 1
|
|
|
+ if t % 10 == 0 then
|
|
|
+ xlua.progress(t, #data)
|
|
|
collectgarbage()
|
|
|
end
|
|
|
end
|
|
|
xlua.progress(#data, #data)
|
|
|
- return loss / #data
|
|
|
+ return loss / loss_count
|
|
|
end
|
|
|
|
|
|
local function create_criterion(model)
|
|
@@ -214,8 +232,7 @@ local function train()
|
|
|
print(train_score)
|
|
|
model:evaluate()
|
|
|
print("# validation")
|
|
|
- local score = validate(model, eval_metric, valid_xy)
|
|
|
-
|
|
|
+ local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize)
|
|
|
table.insert(hist_train, train_score.PSNR)
|
|
|
table.insert(hist_valid, score)
|
|
|
if settings.plot then
|