|
@@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file)
|
|
|
end
|
|
|
end
|
|
|
local function split_data(x, test_size)
|
|
|
- local index = torch.randperm(#x)
|
|
|
- local train_size = #x - test_size
|
|
|
- local train_x = {}
|
|
|
- local valid_x = {}
|
|
|
- for i = 1, train_size do
|
|
|
- train_x[i] = x[index[i]]
|
|
|
- end
|
|
|
- for i = 1, test_size do
|
|
|
- valid_x[i] = x[index[train_size + i]]
|
|
|
+ if settings.validation_filename_split then
|
|
|
+ if not (x[1][2].data and x[1][2].data.basename) then
|
|
|
+ error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
|
|
|
+ end
|
|
|
+ local basename_db = {}
|
|
|
+ for i = 1, #x do
|
|
|
+ local meta = x[i][2].data
|
|
|
+ if basename_db[meta.basename] then
|
|
|
+ table.insert(basename_db[meta.basename], x[i])
|
|
|
+ else
|
|
|
+ basename_db[meta.basename] = {x[i]}
|
|
|
+ end
|
|
|
+ end
|
|
|
+ local basename_list = {}
|
|
|
+ for k, v in pairs(basename_db) do
|
|
|
+ table.insert(basename_list, v)
|
|
|
+ end
|
|
|
+ local index = torch.randperm(#basename_list)
|
|
|
+ local train_x = {}
|
|
|
+ local valid_x = {}
|
|
|
+ local pos = 1
|
|
|
+ for i = 1, #basename_list do
|
|
|
+ if #valid_x >= test_size then
|
|
|
+ break
|
|
|
+ end
|
|
|
+ local xs = basename_list[index[pos]]
|
|
|
+ for j = 1, #xs do
|
|
|
+ table.insert(valid_x, xs[j])
|
|
|
+ end
|
|
|
+ pos = pos + 1
|
|
|
+ end
|
|
|
+ for i = pos, #basename_list do
|
|
|
+ local xs = basename_list[index[i]]
|
|
|
+ for j = 1, #xs do
|
|
|
+ table.insert(train_x, xs[j])
|
|
|
+ end
|
|
|
+ end
|
|
|
+ return train_x, valid_x
|
|
|
+ else
|
|
|
+ local index = torch.randperm(#x)
|
|
|
+ local train_size = #x - test_size
|
|
|
+ local train_x = {}
|
|
|
+ local valid_x = {}
|
|
|
+ for i = 1, train_size do
|
|
|
+ train_x[i] = x[index[i]]
|
|
|
+ end
|
|
|
+ for i = 1, test_size do
|
|
|
+ valid_x[i] = x[index[train_size + i]]
|
|
|
+ end
|
|
|
+ return train_x, valid_x
|
|
|
end
|
|
|
- return train_x, valid_x
|
|
|
end
|
|
|
|
|
|
local g_transform_pool = nil
|