
Fix some multithread bugs

nagadomi · commit 06e08253e4
train.lua | 18 ++++++++++++------
1 file changed, 12 insertions(+), 6 deletions(-)

--- a/train.lua
+++ b/train.lua

@@ -58,10 +58,14 @@ local function transform_pool_init(has_resize, offset)
 	 require 'pl'
 	 local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
 	 package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
+	 require 'torch'
 	 require 'nn'
 	 require 'cunn'
-	 local threads = require 'threads'
 
+	 torch.setnumthreads(1)
+	 torch.setdefaulttensortype("torch.FloatTensor")
+
+	 local threads = require 'threads'
 	 local compression = require 'compression'
 	 local pairwise_transform = require 'pairwise_transform'
 
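
The init closure above runs once inside every worker thread; torch state configured in the main process (OpenMP thread count, default tensor type) is not inherited, so it has to be set again per worker, and pinning each worker to one BLAS/OpenMP thread keeps the pool from oversubscribing the CPU. A minimal standalone sketch of the same pattern (the pool size and job here are illustrative, not from train.lua):

local threads = require 'threads'

local pool = threads.Threads(
   4, -- hypothetical worker count
   function(threadid)
      require 'torch'
      -- per-worker state: not carried over from the main thread
      torch.setnumthreads(1)
      torch.setdefaulttensortype("torch.FloatTensor")
   end
)
pool:addjob(
   function() return torch.rand(3):sum() end, -- runs in a worker
   function(s) print(s) end                   -- runs in the main thread
)
pool:synchronize()
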
@@ -203,7 +207,6 @@ local function make_validation_set(x, n, patches)
 	 g_transform_pool:addjob(
 	    function()
 	       local xy = transformer(input, true, validation_patches)
-	       collectgarbage()
 	       return xy
 	    end,
 	    function(xy)
@@ -213,8 +216,11 @@ local function make_validation_set(x, n, patches)
 	    end
 	 )
       end
-      g_transform_pool:synchronize()
-      xlua.progress(i, #x)
+      if i % 20 == 0 then
+	 collectgarbage()
+	 g_transform_pool:synchronize()
+	 xlua.progress(i, #x)
+      end
    end
    g_transform_pool:synchronize()
    torch.setnumthreads(nthread) -- revert
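
The two hunks above stop make_validation_set from blocking on synchronize() after every queued job: jobs now accumulate, the main thread drains them every 20 iterations, collectgarbage() moves out of the worker closure into the main thread, and the final synchronize() after the loop catches whatever was queued since the last batch. A condensed sketch of this batched pattern, reusing the hypothetical pool from the sketch above (x, transform, and results are illustrative stand-ins):

require 'xlua'

local results = {}
for i = 1, #x do
   local input = x[i] -- captured as an upvalue and shipped to the worker
   pool:addjob(
      function() return transform(input) end,
      function(xy) table.insert(results, xy) end
   )
   if i % 20 == 0 then
      collectgarbage()   -- GC once per batch, in the main thread
      pool:synchronize() -- block until the queued batch has finished
      xlua.progress(i, #x)
   end
end
pool:synchronize() -- drain jobs queued since the last multiple of 20
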
@@ -311,9 +317,9 @@ local function resampling(x, y, train_x)
 	 end
       )
       if t % 50 == 0 then
-	 xlua.progress(t, #train_x)
-	 g_transform_pool:synchronize()
 	 collectgarbage()
+	 g_transform_pool:synchronize()
+	 xlua.progress(t, #train_x)
       end
       if c > x:size(1) then
 	 break
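
The resampling loop already drained its queue at a stride of 50; this hunk only reorders the batch body to match the corrected make_validation_set: collectgarbage() is issued first (while workers may still be running, so the GC presumably overlaps with their work instead of following the blocking wait), synchronize() second, and xlua.progress() last, so the bar reports jobs that have actually completed.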