Browse Source

Support for multi-thread in training

And remove `sys`,`image` and `graphicsmagicks.conveter` from the training code because those causes the deadlock on thread package.
nagadomi 8 năm trước cách đây
mục cha
commit
c2e4bb4380

+ 5 - 6
lib/data_augmentation.lua

@@ -1,7 +1,6 @@
-require 'image'
 local iproc = require 'iproc'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
-
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local data_augmentation = {}
 local data_augmentation = {}
 
 
 local function pcacov(x)
 local function pcacov(x)
@@ -107,11 +106,11 @@ function data_augmentation.flip(src)
       src = src:transpose(2, 3):contiguous()
       src = src:transpose(2, 3):contiguous()
    end
    end
    if flip == 1 then
    if flip == 1 then
-      dest = image.hflip(src)
+      dest = iproc.hflip(src)
    elseif flip == 2 then
    elseif flip == 2 then
-      dest = image.vflip(src)
+      dest = iproc.vflip(src)
    elseif flip == 3 then
    elseif flip == 3 then
-      dest = image.hflip(image.vflip(src))
+      dest = iproc.hflip(iproc.vflip(src))
    elseif flip == 4 then
    elseif flip == 4 then
       dest = src
       dest = src
    end
    end

+ 148 - 2
lib/iproc.lua

@@ -1,5 +1,6 @@
-local gm = require 'graphicsmagick'
-local image = require 'image'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
+local image = nil
 
 
 local iproc = {}
 local iproc = {}
 local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
 local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
@@ -80,6 +81,7 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
    return dest
    return dest
 end
 end
 function iproc.padding(img, w1, w2, h1, h2)
 function iproc.padding(img, w1, w2, h1, h2)
+   image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
    local dst_width = img:size(3) + w1 + w2
    local flow = torch.Tensor(2, dst_height, dst_width)
    local flow = torch.Tensor(2, dst_height, dst_width)
@@ -90,6 +92,7 @@ function iproc.padding(img, w1, w2, h1, h2)
    return image.warp(img, flow, "simple", false, "clamp")
    return image.warp(img, flow, "simple", false, "clamp")
 end
 end
 function iproc.zero_padding(img, w1, w2, h1, h2)
 function iproc.zero_padding(img, w1, w2, h1, h2)
+   image = image or require 'image'
    local dst_height = img:size(2) + h1 + h2
    local dst_height = img:size(2) + h1 + h2
    local dst_width = img:size(3) + w1 + w2
    local dst_width = img:size(3) + w1 + w2
    local flow = torch.Tensor(2, dst_height, dst_width)
    local flow = torch.Tensor(2, dst_height, dst_width)
@@ -126,6 +129,131 @@ function iproc.white_noise(src, std, rgb_weights, gamma)
    end
    end
    return dest
    return dest
 end
 end
+function iproc.hflip(src)
+   local t
+   if src:type() == "torch.ByteTensor" then
+      t = "byte"
+   else
+      t = "float"
+   end
+   if src:size(1) == 3 then
+      color = "RGB"
+   else
+      color = "I"
+   end
+   local im = gm.Image(src, color, "DHW")
+   return im:flop():toTensor(t, color, "DHW")
+end
+function iproc.vflip(src)
+   local t
+   if src:type() == "torch.ByteTensor" then
+      t = "byte"
+   else
+      t = "float"
+   end
+   if src:size(1) == 3 then
+      color = "RGB"
+   else
+      color = "I"
+   end
+   local im = gm.Image(src, color, "DHW")
+   return im:flip():toTensor(t, color, "DHW")
+end
+
+-- from torch/image
+----------------------------------------------------------------------
+-- image.rgb2yuv(image)
+-- converts a RGB image to YUV
+--
+function iproc.rgb2yuv(...)
+   -- arg check
+   local output,input
+   local args = {...}
+   if select('#',...) == 2 then
+      output = args[1]
+      input = args[2]
+   elseif select('#',...) == 1 then
+      input = args[1]
+   else
+      print(dok.usage('image.rgb2yuv',
+                      'transforms an image from RGB to YUV', nil,
+                      {type='torch.Tensor', help='input image', req=true},
+                      '',
+                      {type='torch.Tensor', help='output image', req=true},
+                      {type='torch.Tensor', help='input image', req=true}
+                      ))
+      dok.error('missing input', 'image.rgb2yuv')
+   end
+
+   -- resize
+   output = output or input.new()
+   output:resizeAs(input)
+
+   -- input chanels
+   local inputRed = input[1]
+   local inputGreen = input[2]
+   local inputBlue = input[3]
+
+   -- output chanels
+   local outputY = output[1]
+   local outputU = output[2]
+   local outputV = output[3]
+
+   -- convert
+   outputY:zero():add(0.299, inputRed):add(0.587, inputGreen):add(0.114, inputBlue)
+   outputU:zero():add(-0.14713, inputRed):add(-0.28886, inputGreen):add(0.436, inputBlue)
+   outputV:zero():add(0.615, inputRed):add(-0.51499, inputGreen):add(-0.10001, inputBlue)
+
+   -- return YUV image
+   return output
+end
+
+----------------------------------------------------------------------
+-- image.yuv2rgb(image)
+-- converts a YUV image to RGB
+--
+function iproc.yuv2rgb(...)
+   -- arg check
+   local output,input
+   local args = {...}
+   if select('#',...) == 2 then
+      output = args[1]
+      input = args[2]
+   elseif select('#',...) == 1 then
+      input = args[1]
+   else
+      print(dok.usage('image.yuv2rgb',
+                      'transforms an image from YUV to RGB', nil,
+                      {type='torch.Tensor', help='input image', req=true},
+                      '',
+                      {type='torch.Tensor', help='output image', req=true},
+                      {type='torch.Tensor', help='input image', req=true}
+                      ))
+      dok.error('missing input', 'image.yuv2rgb')
+   end
+
+   -- resize
+   output = output or input.new()
+   output:resizeAs(input)
+
+   -- input chanels
+   local inputY = input[1]
+   local inputU = input[2]
+   local inputV = input[3]
+
+   -- output chanels
+   local outputRed = output[1]
+   local outputGreen = output[2]
+   local outputBlue = output[3]
+
+   -- convert
+   outputRed:copy(inputY):add(1.13983, inputV)
+   outputGreen:copy(inputY):add(-0.39465, inputU):add(-0.58060, inputV)
+   outputBlue:copy(inputY):add(2.03211, inputU)
+
+   -- return RGB image
+   return output
+end
 
 
 local function test_conversion()
 local function test_conversion()
    local a = torch.linspace(0, 255, 256):float():div(255.0)
    local a = torch.linspace(0, 255, 256):float():div(255.0)
@@ -144,6 +272,24 @@ local function test_conversion()
    print(b)
    print(b)
    assert(b:float():sum() == 254.0 * 3)
    assert(b:float():sum() == 254.0 * 3)
 end
 end
+local function test_flip()
+   require 'sys'
+   require 'torch'
+   torch.setdefaulttensortype("torch.FloatTensor")
+   image = require 'image'
+   local src = image.lena()
+   local src_byte = src:clone():mul(255):byte()
+
+   print(src:size())
+   print((image.hflip(src) - iproc.hflip(src)):sum())
+   print((image.hflip(src_byte) - iproc.hflip(src_byte)):sum())
+   print((image.vflip(src) - iproc.vflip(src)):sum())
+   print((image.vflip(src_byte) - iproc.vflip(src_byte)):sum())
+end
+
 --test_conversion()
 --test_conversion()
+--test_flip()
 
 
 return iproc
 return iproc
+
+

+ 4 - 3
lib/pairwise_transform_jpeg.lua

@@ -1,5 +1,6 @@
 local pairwise_utils = require 'pairwise_transform_utils'
 local pairwise_utils = require 'pairwise_transform_utils'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local iproc = require 'iproc'
 local iproc = require 'iproc'
 local pairwise_transform = {}
 local pairwise_transform = {}
 
 
@@ -42,8 +43,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
       yc = iproc.byte2float(yc)
       yc = iproc.byte2float(yc)
       if options.rgb then
       if options.rgb then
       else
       else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+	 yc = iproc.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = iproc.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
       end
       end
       if torch.uniform() < options.nr_rate then
       if torch.uniform() < options.nr_rate then
 	 -- reducing noise
 	 -- reducing noise

+ 4 - 3
lib/pairwise_transform_jpeg_scale.lua

@@ -1,6 +1,7 @@
 local pairwise_utils = require 'pairwise_transform_utils'
 local pairwise_utils = require 'pairwise_transform_utils'
 local iproc = require 'iproc'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local pairwise_transform = {}
 local pairwise_transform = {}
 
 
 local function add_jpeg_noise_(x, quality, options)
 local function add_jpeg_noise_(x, quality, options)
@@ -117,8 +118,8 @@ function pairwise_transform.jpeg_scale(src, scale, style, noise_level, size, off
       yc = iproc.byte2float(yc)
       yc = iproc.byte2float(yc)
       if options.rgb then
       if options.rgb then
       else
       else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+	 yc = iproc.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = iproc.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
       end
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end
    end

+ 4 - 3
lib/pairwise_transform_scale.lua

@@ -1,6 +1,7 @@
 local pairwise_utils = require 'pairwise_transform_utils'
 local pairwise_utils = require 'pairwise_transform_utils'
 local iproc = require 'iproc'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local pairwise_transform = {}
 local pairwise_transform = {}
 
 
 function pairwise_transform.scale(src, scale, size, offset, n, options)
 function pairwise_transform.scale(src, scale, size, offset, n, options)
@@ -50,8 +51,8 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
       yc = iproc.byte2float(yc)
       yc = iproc.byte2float(yc)
       if options.rgb then
       if options.rgb then
       else
       else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+	 yc = iproc.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = iproc.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
       end
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end
    end

+ 4 - 3
lib/pairwise_transform_user.lua

@@ -1,6 +1,7 @@
 local pairwise_utils = require 'pairwise_transform_utils'
 local pairwise_utils = require 'pairwise_transform_utils'
 local iproc = require 'iproc'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local pairwise_transform = {}
 local pairwise_transform = {}
 
 
 local function crop_if_large(x, y, scale_y, max_size, mod)
 local function crop_if_large(x, y, scale_y, max_size, mod)
@@ -47,8 +48,8 @@ function pairwise_transform.user(x, y, size, offset, n, options)
       yc = iproc.byte2float(yc)
       yc = iproc.byte2float(yc)
       if options.rgb then
       if options.rgb then
       else
       else
-	 yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
-	 xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
+	 yc = iproc.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
+	 xc = iproc.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
       end
       end
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
       table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
    end
    end

+ 22 - 14
lib/pairwise_transform_utils.lua

@@ -1,7 +1,7 @@
-require 'image'
 require 'cunn'
 require 'cunn'
 local iproc = require 'iproc'
 local iproc = require 'iproc'
-local gm = require 'graphicsmagick'
+local gm = {}
+gm.Image = require 'graphicsmagick.Image'
 local data_augmentation = require 'data_augmentation'
 local data_augmentation = require 'data_augmentation'
 local pairwise_transform_utils = {}
 local pairwise_transform_utils = {}
 
 
@@ -125,13 +125,13 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
 	 yi = y:transpose(2, 3):contiguous()
 	 yi = y:transpose(2, 3):contiguous()
 	 ri = lowres_y:transpose(2, 3):contiguous()
 	 ri = lowres_y:transpose(2, 3):contiguous()
       end
       end
-      local xv = image.vflip(xi)
+      local xv = iproc.vflip(xi)
       local nv
       local nv
       if x_noise then
       if x_noise then
-	 nv = image.vflip(ni)
+	 nv = iproc.vflip(ni)
       end
       end
-      local yv = image.vflip(yi)
-      local rv = image.vflip(ri)
+      local yv = iproc.vflip(yi)
+      local rv = iproc.vflip(ri)
       table.insert(xs, xi)
       table.insert(xs, xi)
       if ni then
       if ni then
 	 table.insert(ns, ni)
 	 table.insert(ns, ni)
@@ -146,19 +146,19 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
       table.insert(ys, yv)
       table.insert(ys, yv)
       table.insert(ls, rv)
       table.insert(ls, rv)
 
 
-      table.insert(xs, image.hflip(xi))
+      table.insert(xs, iproc.hflip(xi))
       if ni then
       if ni then
-	 table.insert(ns, image.hflip(ni))
+	 table.insert(ns, iproc.hflip(ni))
       end
       end
-      table.insert(ys, image.hflip(yi))
-      table.insert(ls, image.hflip(ri))
+      table.insert(ys, iproc.hflip(yi))
+      table.insert(ls, iproc.hflip(ri))
 
 
-      table.insert(xs, image.hflip(xv))
+      table.insert(xs, iproc.hflip(xv))
       if nv then
       if nv then
-	 table.insert(ns, image.hflip(nv))
+	 table.insert(ns, iproc.hflip(nv))
       end
       end
-      table.insert(ys, image.hflip(yv))
-      table.insert(ls, image.hflip(rv))
+      table.insert(ys, iproc.hflip(yv))
+      table.insert(ls, iproc.hflip(rv))
    end
    end
    return xs, ys, ls, ns
    return xs, ys, ls, ns
 end
 end
@@ -171,6 +171,9 @@ end
 local g_lowres_model = nil
 local g_lowres_model = nil
 local g_lowres_gpu = nil
 local g_lowres_gpu = nil
 function pairwise_transform_utils.low_resolution(src)
 function pairwise_transform_utils.low_resolution(src)
+--[[
+   -- I am not sure that the following process is thraed-safe
+
    g_lowres_model = g_lowres_model or lowres_model()
    g_lowres_model = g_lowres_model or lowres_model()
    if g_lowres_gpu == nil then
    if g_lowres_gpu == nil then
       --benchmark
       --benchmark
@@ -203,6 +206,11 @@ function pairwise_transform_utils.low_resolution(src)
 	 size(src:size(3), src:size(2), "Box"):
 	 size(src:size(3), src:size(2), "Box"):
 	    toTensor("byte", "RGB", "DHW")
 	    toTensor("byte", "RGB", "DHW")
    end
    end
+--]]
+   return gm.Image(src, "RGB", "DHW"):
+      size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
+      size(src:size(3), src:size(2), "Box"):
+      toTensor("byte", "RGB", "DHW")
 end
 end
 
 
 return pairwise_transform_utils
 return pairwise_transform_utils

+ 176 - 126
train.lua

@@ -3,15 +3,14 @@ local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^
 package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
 package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
 require 'optim'
 require 'optim'
 require 'xlua'
 require 'xlua'
-
+require 'image'
 require 'w2nn'
 require 'w2nn'
+local threads = require 'threads'
 local settings = require 'settings'
 local settings = require 'settings'
 local srcnn = require 'srcnn'
 local srcnn = require 'srcnn'
 local minibatch_adam = require 'minibatch_adam'
 local minibatch_adam = require 'minibatch_adam'
 local iproc = require 'iproc'
 local iproc = require 'iproc'
 local reconstruct = require 'reconstruct'
 local reconstruct = require 'reconstruct'
-local compression = require 'compression'
-local pairwise_transform = require 'pairwise_transform'
 local image_loader = require 'image_loader'
 local image_loader = require 'image_loader'
 
 
 local function save_test_scale(model, rgb, file)
 local function save_test_scale(model, rgb, file)
@@ -42,20 +41,155 @@ local function split_data(x, test_size)
    end
    end
    return train_x, valid_x
    return train_x, valid_x
 end
 end
-local function make_validation_set(x, transformer, n, patches)
+
+local g_transform_pool = nil
+local function transform_pool_init(has_resize, offset)
+   g_transform_pool = threads.Threads(
+      torch.getnumthreads(),
+      function(threadid)
+	 require 'pl'
+	 local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+	 package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
+	 require 'nn'
+	 require 'cunn'
+	 local compression = require 'compression'
+	 local pairwise_transform = require 'pairwise_transform'
+
+	 function transformer(x, is_validation, n)
+	    local meta = {data = {}}
+	    local y = nil
+	    if type(x) == "table" and type(x[2]) == "table" then
+	       meta = x[2]
+	       if x[1].x and x[1].y then
+		  y = compression.decompress(x[1].y)
+		  x = compression.decompress(x[1].x)
+	       else
+		  x = compression.decompress(x[1])
+	       end
+	    else
+	       x = compression.decompress(x)
+	    end
+	    n = n or settings.patches
+	    if is_validation == nil then is_validation = false end
+	    local random_color_noise_rate = nil 
+	    local random_overlay_rate = nil
+	    local active_cropping_rate = nil
+	    local active_cropping_tries = nil
+	    if is_validation then
+	       active_cropping_rate = settings.active_cropping_rate
+	       active_cropping_tries = settings.active_cropping_tries
+	       random_color_noise_rate = 0.0
+	       random_overlay_rate = 0.0
+	    else
+	       active_cropping_rate = settings.active_cropping_rate
+	       active_cropping_tries = settings.active_cropping_tries
+	       random_color_noise_rate = settings.random_color_noise_rate
+	       random_overlay_rate = settings.random_overlay_rate
+	    end
+	    if settings.method == "scale" then
+	       local conf = tablex.update({
+		     downsampling_filters = settings.downsampling_filters,
+		     random_half_rate = settings.random_half_rate,
+		     random_color_noise_rate = random_color_noise_rate,
+		     random_overlay_rate = random_overlay_rate,
+		     random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
+		     max_size = settings.max_size,
+		     active_cropping_rate = active_cropping_rate,
+		     active_cropping_tries = active_cropping_tries,
+		     rgb = (settings.color == "rgb"),
+		     x_upsampling = not has_resize,
+		     resize_blur_min = settings.resize_blur_min,
+		     resize_blur_max = settings.resize_blur_max}, meta)
+	       return pairwise_transform.scale(x,
+					       settings.scale,
+					       settings.crop_size, offset,
+					       n, conf)
+	    elseif settings.method == "noise" then
+	       local conf = tablex.update({
+		     random_half_rate = settings.random_half_rate,
+		     random_color_noise_rate = random_color_noise_rate,
+		     random_overlay_rate = random_overlay_rate,
+		     random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
+		     max_size = settings.max_size,
+		     jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
+		     active_cropping_rate = active_cropping_rate,
+		     active_cropping_tries = active_cropping_tries,
+		     nr_rate = settings.nr_rate,
+		     rgb = (settings.color == "rgb")}, meta)
+	       return pairwise_transform.jpeg(x,
+					      settings.style,
+					      settings.noise_level,
+					      settings.crop_size, offset,
+					      n, conf)
+	    elseif settings.method == "noise_scale" then
+	       local conf = tablex.update({
+		     downsampling_filters = settings.downsampling_filters,
+		     random_half_rate = settings.random_half_rate,
+		     random_color_noise_rate = random_color_noise_rate,
+		     random_overlay_rate = random_overlay_rate,
+		     random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
+		     max_size = settings.max_size,
+		     jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
+		     nr_rate = settings.nr_rate,
+		     active_cropping_rate = active_cropping_rate,
+		     active_cropping_tries = active_cropping_tries,
+		     rgb = (settings.color == "rgb"),
+		     x_upsampling = not has_resize,
+		     resize_blur_min = settings.resize_blur_min,
+		     resize_blur_max = settings.resize_blur_max}, meta)
+	       return pairwise_transform.jpeg_scale(x,
+						    settings.scale,
+						    settings.style,
+						    settings.noise_level,
+						    settings.crop_size, offset,
+						    n, conf)
+	    elseif settings.method == "user" then
+	       local conf = tablex.update({
+		     max_size = settings.max_size,
+		     active_cropping_rate = active_cropping_rate,
+		     active_cropping_tries = active_cropping_tries,
+		     rgb = (settings.color == "rgb")}, meta)
+	       return pairwise_transform.user(x, y,
+					      settings.crop_size, offset,
+					      n, conf)
+	    end
+	 end
+      end
+   )
+   g_transform_pool:synchronize()
+end
+
+local function make_validation_set(x, n, patches)
+   local nthread = torch.getnumthreads()
    n = n or 4
    n = n or 4
    local validation_patches = math.min(16, patches or 16)
    local validation_patches = math.min(16, patches or 16)
    local data = {}
    local data = {}
+
+   g_transform_pool:synchronize()
+   torch.setnumthreads(1) -- 1
+
    for i = 1, #x do
    for i = 1, #x do
       for k = 1, math.max(n / validation_patches, 1) do
       for k = 1, math.max(n / validation_patches, 1) do
-	 local xy = transformer(x[i], true, validation_patches)
-	 for j = 1, #xy do
-	    table.insert(data, {x = xy[j][1], y = xy[j][2]})
-	 end
+	 local input = x[i]
+	 g_transform_pool:addjob(
+	    function()
+	       local xy = transformer(input, true, validation_patches)
+	       collectgarbage()
+	       return xy
+	    end,
+	    function(xy)
+	       for j = 1, #xy do
+		  table.insert(data, {x = xy[j][1], y = xy[j][2]})
+	       end
+	    end
+	 )
       end
       end
+      g_transform_pool:synchronize()
       xlua.progress(i, #x)
       xlua.progress(i, #x)
-      collectgarbage()
    end
    end
+   g_transform_pool:synchronize()
+   torch.setnumthreads(nthread) -- revert
+
    local new_data = {}
    local new_data = {}
    local perm = torch.randperm(#data)
    local perm = torch.randperm(#data)
    for i = 1, perm:size(1) do
    for i = 1, perm:size(1) do
@@ -118,128 +252,44 @@ local function create_criterion(model)
       return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
       return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
    end
    end
 end
 end
-local function transformer(model, x, is_validation, n, offset)
-   local meta = {data = {}}
-   local y = nil
-   if type(x) == "table" and type(x[2]) == "table" then
-      meta = x[2]
-      if x[1].x and x[1].y then
-	 y = compression.decompress(x[1].y)
-	 x = compression.decompress(x[1].x)
-      else
-	 x = compression.decompress(x[1])
-      end
-   else
-      x = compression.decompress(x)
-   end
-   n = n or settings.patches
-   if is_validation == nil then is_validation = false end
-   local random_color_noise_rate = nil 
-   local random_overlay_rate = nil
-   local active_cropping_rate = nil
-   local active_cropping_tries = nil
-   if is_validation then
-      active_cropping_rate = settings.active_cropping_rate
-      active_cropping_tries = settings.active_cropping_tries
-      random_color_noise_rate = 0.0
-      random_overlay_rate = 0.0
-   else
-      active_cropping_rate = settings.active_cropping_rate
-      active_cropping_tries = settings.active_cropping_tries
-      random_color_noise_rate = settings.random_color_noise_rate
-      random_overlay_rate = settings.random_overlay_rate
-   end
-   if settings.method == "scale" then
-      local conf = tablex.update({
-	    downsampling_filters = settings.downsampling_filters,
-	    random_half_rate = settings.random_half_rate,
-	    random_color_noise_rate = random_color_noise_rate,
-	    random_overlay_rate = random_overlay_rate,
-	    random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
-	    max_size = settings.max_size,
-	    active_cropping_rate = active_cropping_rate,
-	    active_cropping_tries = active_cropping_tries,
-	    rgb = (settings.color == "rgb"),
-	    x_upsampling = not reconstruct.has_resize(model),
-	    resize_blur_min = settings.resize_blur_min,
-	 resize_blur_max = settings.resize_blur_max}, meta)
-      return pairwise_transform.scale(x,
-				      settings.scale,
-				      settings.crop_size, offset,
-				      n, conf)
-   elseif settings.method == "noise" then
-      local conf = tablex.update({
-	    random_half_rate = settings.random_half_rate,
-	    random_color_noise_rate = random_color_noise_rate,
-	    random_overlay_rate = random_overlay_rate,
-	    random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
-	    max_size = settings.max_size,
-	    jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
-	    active_cropping_rate = active_cropping_rate,
-	    active_cropping_tries = active_cropping_tries,
-	    nr_rate = settings.nr_rate,
-	    rgb = (settings.color == "rgb")}, meta)
-      return pairwise_transform.jpeg(x,
-				     settings.style,
-				     settings.noise_level,
-				     settings.crop_size, offset,
-				     n, conf)
-   elseif settings.method == "noise_scale" then
-      local conf = tablex.update({
-	    downsampling_filters = settings.downsampling_filters,
-	    random_half_rate = settings.random_half_rate,
-	    random_color_noise_rate = random_color_noise_rate,
-	    random_overlay_rate = random_overlay_rate,
-	    random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
-	    max_size = settings.max_size,
-	    jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
-	    nr_rate = settings.nr_rate,
-	    active_cropping_rate = active_cropping_rate,
-	    active_cropping_tries = active_cropping_tries,
-	    rgb = (settings.color == "rgb"),
-	    x_upsampling = not reconstruct.has_resize(model),
-	    resize_blur_min = settings.resize_blur_min,
-	    resize_blur_max = settings.resize_blur_max}, meta)
-      return pairwise_transform.jpeg_scale(x,
-					   settings.scale,
-					   settings.style,
-					   settings.noise_level,
-					   settings.crop_size, offset,
-					   n, conf)
-   elseif settings.method == "user" then
-      local conf = tablex.update({
-	    max_size = settings.max_size,
-	    active_cropping_rate = active_cropping_rate,
-	    active_cropping_tries = active_cropping_tries,
-	    rgb = (settings.color == "rgb")}, meta)
-      return pairwise_transform.user(x, y,
-				     settings.crop_size, offset,
-				     n, conf)
-   end
-end
 
 
-local function resampling(x, y, train_x, transformer, input_size, target_size)
+local function resampling(x, y, train_x)
    local c = 1
    local c = 1
+   local nthread = torch.getnumthreads()
    local shuffle = torch.randperm(#train_x)
    local shuffle = torch.randperm(#train_x)
+
+   torch.setnumthreads(1) -- 1
    for t = 1, #train_x do
    for t = 1, #train_x do
-      xlua.progress(t, #train_x)
-      local xy = transformer(train_x[shuffle[t]], false, settings.patches)
-      for i = 1, #xy do
-         x[c]:copy(xy[i][1])
-	 y[c]:copy(xy[i][2])
-	 c = c + 1
-	 if c > x:size(1) then
-	    break
+      local input = train_x[shuffle[t]]
+      g_transform_pool:addjob(
+	 function()
+	    local xy = transformer(input, false, settings.patches)
+	    return xy
+	 end,
+	 function(xy)
+	    for i = 1, #xy do
+	       if c <= x:size(1) then
+		  x[c]:copy(xy[i][1])
+		  y[c]:copy(xy[i][2])
+		  c = c + 1
+	       else
+		  break
+	       end
+	    end
 	 end
 	 end
+      )
+      if t % 50 == 0 then
+	 xlua.progress(t, #train_x)
+	 g_transform_pool:synchronize()
+	 collectgarbage()
       end
       end
       if c > x:size(1) then
       if c > x:size(1) then
 	 break
 	 break
       end
       end
-      if t % 50 == 0 then
-	 collectgarbage()
-      end
    end
    end
+   g_transform_pool:synchronize()
    xlua.progress(#train_x, #train_x)
    xlua.progress(#train_x, #train_x)
+   torch.setnumthreads(nthread) -- revert
 end
 end
 local function get_oracle_data(x, y, instance_loss, k, samples)
 local function get_oracle_data(x, y, instance_loss, k, samples)
    local index = torch.LongTensor(instance_loss:size(1))
    local index = torch.LongTensor(instance_loss:size(1))
@@ -262,6 +312,7 @@ local function get_oracle_data(x, y, instance_loss, k, samples)
 end
 end
 
 
 local function remove_small_image(x)
 local function remove_small_image(x)
+   local compression = require 'compression'
    local new_x = {}
    local new_x = {}
    for i = 1, #x do
    for i = 1, #x do
       local xe, meta, x_s
       local xe, meta, x_s
@@ -304,9 +355,8 @@ local function train()
    dir.makepath(settings.model_dir)
    dir.makepath(settings.model_dir)
 
 
    local offset = reconstruct.offset_size(model)
    local offset = reconstruct.offset_size(model)
-   local pairwise_func = function(x, is_validation, n)
-      return transformer(model, x, is_validation, n, offset)
-   end
+   transform_pool_init(reconstruct.has_resize(model), offset)
+
    local criterion = create_criterion(model)
    local criterion = create_criterion(model)
    local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
    local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
    local x = remove_small_image(torch.load(settings.images))
    local x = remove_small_image(torch.load(settings.images))
@@ -324,7 +374,7 @@ local function train()
    end
    end
    local best_score = 1000.0
    local best_score = 1000.0
    print("# make validation-set")
    print("# make validation-set")
-   local valid_xy = make_validation_set(valid_x, pairwise_func,
+   local valid_xy = make_validation_set(valid_x, 
 					settings.validation_crops,
 					settings.validation_crops,
 					settings.patches)
 					settings.patches)
    valid_x = nil
    valid_x = nil
@@ -358,7 +408,7 @@ local function train()
 	 if oracle_n > 0 then
 	 if oracle_n > 0 then
 	    local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n)
 	    local oracle_x, oracle_y = get_oracle_data(x, y, instance_loss, oracle_k, oracle_n)
 	    resampling(x:narrow(1, oracle_x:size(1) + 1, x:size(1)-oracle_x:size(1)),
 	    resampling(x:narrow(1, oracle_x:size(1) + 1, x:size(1)-oracle_x:size(1)),
-		       y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x, pairwise_func)
+		       y:narrow(1, oracle_x:size(1) + 1, x:size(1) - oracle_x:size(1)), train_x)
 	    x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x)
 	    x:narrow(1, 1, oracle_x:size(1)):copy(oracle_x)
 	    y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y)
 	    y:narrow(1, 1, oracle_y:size(1)):copy(oracle_y)
 
 
@@ -374,7 +424,7 @@ local function train()
 			     min = 0,
 			     min = 0,
 			     max = 1}))
 			     max = 1}))
 	 else
 	 else
-	    resampling(x, y, train_x, pairwise_func)
+	    resampling(x, y, train_x)
 	 end
 	 end
       else
       else
 	 resampling(x, y, train_x, pairwise_func)
 	 resampling(x, y, train_x, pairwise_func)