|
@@ -43,11 +43,15 @@ local function split_data(x, test_size)
|
|
|
end
|
|
|
|
|
|
local g_transform_pool = nil
|
|
|
+local g_mutex = nil
|
|
|
+local g_mutex_id = nil
|
|
|
local function transform_pool_init(has_resize, offset)
|
|
|
local nthread = torch.getnumthreads()
|
|
|
if (settings.thread > 0) then
|
|
|
nthread = settings.thread
|
|
|
end
|
|
|
+ g_mutex = threads.Mutex()
|
|
|
+ g_mutex_id = g_mutex:id()
|
|
|
g_transform_pool = threads.Threads(
|
|
|
nthread,
|
|
|
function(threadid)
|
|
@@ -56,10 +60,13 @@ local function transform_pool_init(has_resize, offset)
|
|
|
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
|
|
require 'nn'
|
|
|
require 'cunn'
|
|
|
+ local threads = require 'threads'
|
|
|
+
|
|
|
local compression = require 'compression'
|
|
|
local pairwise_transform = require 'pairwise_transform'
|
|
|
|
|
|
function transformer(x, is_validation, n)
|
|
|
+ local mutex = threads.Mutex(g_mutex_id)
|
|
|
local meta = {data = {}}
|
|
|
local y = nil
|
|
|
if type(x) == "table" and type(x[2]) == "table" then
|
|
@@ -92,6 +99,7 @@ local function transform_pool_init(has_resize, offset)
|
|
|
end
|
|
|
if settings.method == "scale" then
|
|
|
local conf = tablex.update({
|
|
|
+ mutex = mutex,
|
|
|
downsampling_filters = settings.downsampling_filters,
|
|
|
random_half_rate = settings.random_half_rate,
|
|
|
random_color_noise_rate = random_color_noise_rate,
|
|
@@ -114,6 +122,7 @@ local function transform_pool_init(has_resize, offset)
|
|
|
n, conf)
|
|
|
elseif settings.method == "noise" then
|
|
|
local conf = tablex.update({
|
|
|
+ mutex = mutex,
|
|
|
random_half_rate = settings.random_half_rate,
|
|
|
random_color_noise_rate = random_color_noise_rate,
|
|
|
random_overlay_rate = random_overlay_rate,
|
|
@@ -135,6 +144,7 @@ local function transform_pool_init(has_resize, offset)
|
|
|
n, conf)
|
|
|
elseif settings.method == "noise_scale" then
|
|
|
local conf = tablex.update({
|
|
|
+ mutex = mutex,
|
|
|
downsampling_filters = settings.downsampling_filters,
|
|
|
random_half_rate = settings.random_half_rate,
|
|
|
random_color_noise_rate = random_color_noise_rate,
|