|
@@ -35,18 +35,19 @@ local function split_data(x, test_size)
|
|
|
end
|
|
|
return train_x, valid_x
|
|
|
end
|
|
|
-local function make_validation_set(x, transformer, n)
|
|
|
+local function make_validation_set(x, transformer, n, batch_size)
|
|
|
n = n or 4
|
|
|
local data = {}
|
|
|
for i = 1, #x do
|
|
|
- for k = 1, math.max(n / 8, 1) do
|
|
|
- local xy = transformer(x[i], true, 8)
|
|
|
+ for k = 1, math.max(n / batch_size, 1) do
|
|
|
+ local xy = transformer(x[i], true, batch_size)
|
|
|
+ local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
|
|
|
+ local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
|
|
|
for j = 1, #xy do
|
|
|
- local x = xy[j][1]
|
|
|
- local y = xy[j][2]
|
|
|
- table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
|
|
|
- y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
|
|
|
+ tx[j]:copy(xy[j][1])
|
|
|
+ ty[j]:copy(xy[j][2])
|
|
|
end
|
|
|
+ table.insert(data, {x = tx, y = ty})
|
|
|
end
|
|
|
xlua.progress(i, #x)
|
|
|
collectgarbage()
|
|
@@ -58,11 +59,12 @@ local function validate(model, criterion, data)
|
|
|
for i = 1, #data do
|
|
|
local z = model:forward(data[i].x:cuda())
|
|
|
loss = loss + criterion:forward(z, data[i].y:cuda())
|
|
|
- xlua.progress(i, #data)
|
|
|
- if i % 10 == 0 then
|
|
|
+ if i % 100 == 0 then
|
|
|
+ xlua.progress(i, #data)
|
|
|
collectgarbage()
|
|
|
end
|
|
|
end
|
|
|
+ xlua.progress(#data, #data)
|
|
|
return loss / #data
|
|
|
end
|
|
|
|
|
@@ -71,10 +73,10 @@ local function create_criterion(model)
|
|
|
local offset = reconstruct.offset_size(model)
|
|
|
local output_w = settings.crop_size - offset * 2
|
|
|
local weight = torch.Tensor(3, output_w * output_w)
|
|
|
- weight[1]:fill(0.299 * 3) -- R
|
|
|
- weight[2]:fill(0.587 * 3) -- G
|
|
|
- weight[3]:fill(0.114 * 3) -- B
|
|
|
- return w2nn.WeightedMSECriterion(weight):cuda()
|
|
|
+ weight[1]:fill(0.29891 * 3) -- R
|
|
|
+ weight[2]:fill(0.58661 * 3) -- G
|
|
|
+ weight[3]:fill(0.11448 * 3) -- B
|
|
|
+ return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
|
|
|
else
|
|
|
return nn.MSECriterion():cuda()
|
|
|
end
|
|
@@ -151,7 +153,9 @@ local function train()
|
|
|
end
|
|
|
local best_score = 100000.0
|
|
|
print("# make validation-set")
|
|
|
- local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops)
|
|
|
+ local valid_xy = make_validation_set(valid_x, pairwise_func,
|
|
|
+ settings.validation_crops,
|
|
|
+ settings.batch_size)
|
|
|
valid_x = nil
|
|
|
|
|
|
collectgarbage()
|