Parcourir la source

Fix the number of threads

nagadomi il y a 8 ans
Parent
commit
33e6bc888e
1 fichiers modifiés avec 13 ajouts et 3 suppressions
  1. 13 3
      train.lua

+ 13 - 3
train.lua

@@ -44,8 +44,12 @@ end
 
 local g_transform_pool = nil
 local function transform_pool_init(has_resize, offset)
+   local nthread = torch.getnumthreads()
+   if (settings.thread > 0) then
+      nthread = settings.thread
+   end
    g_transform_pool = threads.Threads(
-      torch.getnumthreads(),
+      nthread,
       function(threadid)
 	 require 'pl'
 	 local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
@@ -161,6 +165,9 @@ end
 
 local function make_validation_set(x, n, patches)
    local nthread = torch.getnumthreads()
+   if (settings.thread > 0) then
+      nthread = settings.thread
+   end
    n = n or 4
    local validation_patches = math.min(16, patches or 16)
    local data = {}
@@ -255,10 +262,13 @@ end
 
 local function resampling(x, y, train_x)
    local c = 1
-   local nthread = torch.getnumthreads()
    local shuffle = torch.randperm(#train_x)
-
+   local nthread = torch.getnumthreads()
+   if (settings.thread > 0) then
+      nthread = settings.thread
+   end
    torch.setnumthreads(1) -- 1
+
    for t = 1, #train_x do
       local input = train_x[shuffle[t]]
       g_transform_pool:addjob(