|
@@ -58,10 +58,14 @@ local function transform_pool_init(has_resize, offset)
|
|
|
require 'pl'
|
|
|
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
|
|
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
|
|
+ require 'torch'
|
|
|
require 'nn'
|
|
|
require 'cunn'
|
|
|
- local threads = require 'threads'
|
|
|
|
|
|
+ torch.setnumthreads(1)
|
|
|
+ torch.setdefaulttensortype("torch.FloatTensor")
|
|
|
+
|
|
|
+ local threads = require 'threads'
|
|
|
local compression = require 'compression'
|
|
|
local pairwise_transform = require 'pairwise_transform'
|
|
|
|
|
@@ -203,7 +207,6 @@ local function make_validation_set(x, n, patches)
|
|
|
g_transform_pool:addjob(
|
|
|
function()
|
|
|
local xy = transformer(input, true, validation_patches)
|
|
|
- collectgarbage()
|
|
|
return xy
|
|
|
end,
|
|
|
function(xy)
|
|
@@ -213,8 +216,11 @@ local function make_validation_set(x, n, patches)
|
|
|
end
|
|
|
)
|
|
|
end
|
|
|
- g_transform_pool:synchronize()
|
|
|
- xlua.progress(i, #x)
|
|
|
+ if i % 20 == 0 then
|
|
|
+ collectgarbage()
|
|
|
+ g_transform_pool:synchronize()
|
|
|
+ xlua.progress(i, #x)
|
|
|
+ end
|
|
|
end
|
|
|
g_transform_pool:synchronize()
|
|
|
torch.setnumthreads(nthread) -- revert
|
|
@@ -311,9 +317,9 @@ local function resampling(x, y, train_x)
|
|
|
end
|
|
|
)
|
|
|
if t % 50 == 0 then
|
|
|
- xlua.progress(t, #train_x)
|
|
|
- g_transform_pool:synchronize()
|
|
|
collectgarbage()
|
|
|
+ g_transform_pool:synchronize()
|
|
|
+ xlua.progress(t, #train_x)
|
|
|
end
|
|
|
if c > x:size(1) then
|
|
|
break
|