Add support for multi GPU training (data parallel)

train.lua -gpu 1,3,4
When using multi-GPU mode, nccl.torch is required.
nagadomi 8 years ago
parent
commit b7e116de54
3 changed files with 59 additions and 11 deletions
  1. lib/settings.lua (+12 -3)
  2. lib/w2nn.lua (+41 -0)
  3. train.lua (+6 -8)

+ 12 - 3
lib/settings.lua

@@ -18,7 +18,6 @@ local cmd = torch.CmdLine()
 cmd:text()
 cmd:text("waifu2x-training")
 cmd:text("Options:")
-cmd:option("-gpu", -1, 'GPU Device ID')
 cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
 cmd:option("-data_dir", "./data", 'path to data directory')
 cmd:option("-backend", "cunn", '(cunn|cudnn)')
@@ -74,7 +73,7 @@ cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
-cmd:option("-gpu", 1, 'Device ID')
+cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
 cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
 cmd:option("-update_criterion", "mse", 'mse|loss')
 
@@ -168,6 +167,16 @@ end
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
 
-cutorch.setDevice(opt.gpu)
+if settings.gpu:len() > 0 then
+   local gpus = {}
+   local gpu_string = utils.split(settings.gpu, ",")
+   for i = 1, #gpu_string do
+      table.insert(gpus, tonumber(gpu_string[i]))
+   end
+   settings.gpu = gpus
+else
+   settings.gpu = {1}
+end
+cutorch.setDevice(settings.gpu[1])
 
 return settings
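
For reference, here is the ID-list parsing above in a minimal standalone form. The split helper below is a simplified stand-in for the repository's utils.split (which is assumed to split on a plain separator; the real helper is not shown in this diff):

```lua
-- Minimal sketch of the -gpu option parsing from lib/settings.lua.
-- split() is a simplified stand-in for utils.split and only handles
-- plain single-character separators such as ",".
local function split(s, sep)
   local fields = {}
   for field in string.gmatch(s, "([^" .. sep .. "]+)") do
      table.insert(fields, field)
   end
   return fields
end

local function parse_gpu_option(gpu)
   if gpu:len() > 0 then
      local gpus = {}
      for _, id in ipairs(split(gpu, ",")) do
         table.insert(gpus, tonumber(id))
      end
      return gpus
   end
   return {1} -- default: first CUDA device only
end

print(table.concat(parse_gpu_option("1,3,4"), " ")) --> 1 3 4
```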

+ 41 - 0
lib/w2nn.lua

@@ -9,6 +9,40 @@ end
 local function load_cudnn()
    cudnn = require('cudnn')
 end
+local function make_data_parallel_table(model, gpus)
+   if cudnn then
+      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
+      local dpt = nn.DataParallelTable(1, true, true)
+	 :add(model, gpus)
+	 :threads(function()
+	       require 'pl'
+	       local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+	       package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+	       require 'torch'
+	       require 'cunn'
+	       require 'w2nn'
+	       local cudnn = require 'cudnn'
+	       cudnn.fastest, cudnn.benchmark = fastest, benchmark
+		 end)
+      dpt.gradInput = nil
+      model = dpt:cuda()
+   else
+      local dpt = nn.DataParallelTable(1, true, true)
+	    :add(model, gpus)
+	 :threads(function()
+	       require 'pl'
+	       local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+	       package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+	       require 'torch'
+	       require 'cunn'
+	       require 'w2nn'
+		 end)
+      dpt.gradInput = nil
+      model = dpt:cuda()
+   end
+   return model
+end
+
 if w2nn then
    return w2nn
 else
@@ -27,6 +61,13 @@ else
       model:cuda():evaluate()
       return model
    end
+   function w2nn.data_parallel(model, gpus)
+      if #gpus > 1 then
+	 return make_data_parallel_table(model, gpus)
+      else
+	 return model
+      end
+   end
    require 'LeakyReLU'
    require 'ClippedWeightedHuberCriterion'
    require 'ClippedMSECriterion'
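
nn.DataParallelTable is provided by the cunn package. The constructor arguments used above are (dimension, flattenParameters, useNCCL): each minibatch is split along the given dimension across the listed devices, flattened parameters speed up broadcasting the updated weights, and useNCCL = true is what makes nccl.torch a requirement. Setting dpt.gradInput = nil is a common idiom to skip copying the input gradient back to the first device when the wrapper is the outermost module. A minimal usage sketch, assuming two visible GPUs (devices {1, 2} here are illustrative):

```lua
-- Minimal DataParallelTable sketch (assumes a working cunn install
-- and at least two GPUs).
require 'cunn'

local net = nn.Sequential()
   :add(nn.SpatialConvolution(3, 16, 3, 3, 1, 1, 1, 1))
   :add(nn.ReLU(true))
   :cuda()

-- Split minibatches along dim 1; flatten parameters; use NCCL.
local dpt = nn.DataParallelTable(1, true, true):add(net, {1, 2})
dpt.gradInput = nil -- do not gather the input gradient back
dpt = dpt:cuda()

local input = torch.CudaTensor(8, 3, 32, 32):uniform()
local output = dpt:forward(input) -- 8 samples, 4 per device
print(output:size())
```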

+ 6 - 8
train.lua

@@ -480,8 +480,9 @@ local function train()
 		       ch, settings.crop_size, settings.crop_size)
    end
    local instance_loss = nil
+   local pmodel = w2nn.data_parallel(model, settings.gpu)
    for epoch = 1, settings.epoch do
-      model:training()
+      pmodel:training()
       print("# " .. epoch)
       if adam_config.learningRate then
 	 print("learning rate: " .. adam_config.learningRate)
@@ -519,13 +520,13 @@ local function train()
       instance_loss = torch.Tensor(x:size(1)):zero()
 
       for i = 1, settings.inner_epoch do
-	 model:training()
-	 local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
+	 pmodel:training()
+	 local train_score, il = minibatch_adam(pmodel, criterion, eval_metric, x, y, adam_config)
 	 instance_loss:copy(il)
 	 print(train_score)
-	 model:evaluate()
+	 pmodel:evaluate()
 	 print("# validation")
-	 local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
+	 local score = validate(pmodel, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
 	 table.insert(hist_train, train_score.loss)
 	 table.insert(hist_valid, score.loss)
 	 if settings.plot then
@@ -593,9 +594,6 @@ local function train()
       end
    end
 end
-if settings.gpu > 0 then
-   cutorch.setDevice(settings.gpu)
-end
 torch.manualSeed(settings.seed)
 cutorch.manualSeed(settings.seed)
 print(settings)
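
In the loop above, the wrapped pmodel handles training and validation while the original model presumably remains the handle used for checkpointing elsewhere in train.lua; with :add(model, gpus), DataParallelTable keeps model itself as the replica on the first listed device. A hypothetical unwrap helper (not part of this commit) for recovering the plain module from a wrapped one:

```lua
require 'cunn' -- provides nn.DataParallelTable

-- Hypothetical helper (not in this commit): recover the underlying
-- module from a DataParallelTable before torch.save()ing a checkpoint.
local function unwrap(pmodel)
   if torch.type(pmodel) == "nn.DataParallelTable" then
      return pmodel:get(1) -- module on the first listed device
   end
   return pmodel
end
```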