nagadomi hace 8 años
padre
commit
a14e6acec3
Se han modificado 1 ficheros con 10 adiciones y 0 borrados
  1. 10 0
      train.lua

+ 10 - 0
train.lua

@@ -43,11 +43,15 @@ local function split_data(x, test_size)
 end
 
 local g_transform_pool = nil
+local g_mutex = nil
+local g_mutex_id = nil
 local function transform_pool_init(has_resize, offset)
    local nthread = torch.getnumthreads()
    if (settings.thread > 0) then
       nthread = settings.thread
    end
+   g_mutex = threads.Mutex()
+   g_mutex_id = g_mutex:id()
    g_transform_pool = threads.Threads(
       nthread,
       function(threadid)
@@ -56,10 +60,13 @@ local function transform_pool_init(has_resize, offset)
 	 package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
 	 require 'nn'
 	 require 'cunn'
+	 local threads = require 'threads'
+
 	 local compression = require 'compression'
 	 local pairwise_transform = require 'pairwise_transform'
 
 	 function transformer(x, is_validation, n)
+	    local mutex = threads.Mutex(g_mutex_id)
 	    local meta = {data = {}}
 	    local y = nil
 	    if type(x) == "table" and type(x[2]) == "table" then
@@ -92,6 +99,7 @@ local function transform_pool_init(has_resize, offset)
 	    end
 	    if settings.method == "scale" then
 	       local conf = tablex.update({
+		     mutex = mutex,
 		     downsampling_filters = settings.downsampling_filters,
 		     random_half_rate = settings.random_half_rate,
 		     random_color_noise_rate = random_color_noise_rate,
@@ -114,6 +122,7 @@ local function transform_pool_init(has_resize, offset)
 					       n, conf)
 	    elseif settings.method == "noise" then
 	       local conf = tablex.update({
+		     mutex = mutex,
 		     random_half_rate = settings.random_half_rate,
 		     random_color_noise_rate = random_color_noise_rate,
 		     random_overlay_rate = random_overlay_rate,
@@ -135,6 +144,7 @@ local function transform_pool_init(has_resize, offset)
 					      n, conf)
 	    elseif settings.method == "noise_scale" then
 	       local conf = tablex.update({
+		     mutex = mutex,
 		     downsampling_filters = settings.downsampling_filters,
 		     random_half_rate = settings.random_half_rate,
 		     random_color_noise_rate = random_color_noise_rate,