|
@@ -44,8 +44,12 @@ end
|
|
|
|
|
|
local g_transform_pool = nil
|
|
|
local function transform_pool_init(has_resize, offset)
|
|
|
+ local nthread = torch.getnumthreads()
|
|
|
+ if (settings.thread > 0) then
|
|
|
+ nthread = settings.thread
|
|
|
+ end
|
|
|
g_transform_pool = threads.Threads(
|
|
|
- torch.getnumthreads(),
|
|
|
+ nthread,
|
|
|
function(threadid)
|
|
|
require 'pl'
|
|
|
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
|
@@ -161,6 +165,9 @@ end
|
|
|
|
|
|
local function make_validation_set(x, n, patches)
|
|
|
local nthread = torch.getnumthreads()
|
|
|
+ if (settings.thread > 0) then
|
|
|
+ nthread = settings.thread
|
|
|
+ end
|
|
|
n = n or 4
|
|
|
local validation_patches = math.min(16, patches or 16)
|
|
|
local data = {}
|
|
@@ -255,10 +262,13 @@ end
|
|
|
|
|
|
local function resampling(x, y, train_x)
|
|
|
local c = 1
|
|
|
- local nthread = torch.getnumthreads()
|
|
|
local shuffle = torch.randperm(#train_x)
|
|
|
-
|
|
|
+ local nthread = torch.getnumthreads()
|
|
|
+ if (settings.thread > 0) then
|
|
|
+ nthread = settings.thread
|
|
|
+ end
|
|
|
torch.setnumthreads(1) -- 1
|
|
|
+
|
|
|
for t = 1, #train_x do
|
|
|
local input = train_x[shuffle[t]]
|
|
|
g_transform_pool:addjob(
|