|
@@ -18,7 +18,6 @@ local cmd = torch.CmdLine()
|
|
cmd:text()
|
|
cmd:text()
|
|
cmd:text("waifu2x-training")
|
|
cmd:text("waifu2x-training")
|
|
cmd:text("Options:")
|
|
cmd:text("Options:")
|
|
-cmd:option("-gpu", -1, 'GPU Device ID')
|
|
|
|
cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
|
|
cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
|
|
cmd:option("-data_dir", "./data", 'path to data directory')
|
|
cmd:option("-data_dir", "./data", 'path to data directory')
|
|
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
|
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
|
@@ -74,7 +73,7 @@ cmd:option("-oracle_drop_rate", 0.5, '')
|
|
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
|
|
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
|
|
cmd:option("-resume", "", 'resume model file')
|
|
cmd:option("-resume", "", 'resume model file')
|
|
cmd:option("-name", "user", 'model name for user method')
|
|
cmd:option("-name", "user", 'model name for user method')
|
|
-cmd:option("-gpu", 1, 'Device ID')
|
|
|
|
|
|
+cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
|
|
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
|
|
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
|
|
cmd:option("-update_criterion", "mse", 'mse|loss')
|
|
cmd:option("-update_criterion", "mse", 'mse|loss')
|
|
|
|
|
|
@@ -168,6 +167,16 @@ end
|
|
settings.images = string.format("%s/images.t7", settings.data_dir)
|
|
settings.images = string.format("%s/images.t7", settings.data_dir)
|
|
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
|
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
|
|
|
|
|
-cutorch.setDevice(opt.gpu)
|
|
|
|
|
|
+if settings.gpu:len() > 0 then
|
|
|
|
+ local gpus = {}
|
|
|
|
+ local gpu_string = utils.split(settings.gpu, ",")
|
|
|
|
+ for i = 1, #gpu_string do
|
|
|
|
+ table.insert(gpus, tonumber(gpu_string[i]))
|
|
|
|
+ end
|
|
|
|
+ settings.gpu = gpus
|
|
|
|
+else
|
|
|
|
+ settings.gpu = {1}
|
|
|
|
+end
|
|
|
|
+cutorch.setDevice(settings.gpu[1])
|
|
|
|
|
|
return settings
|
|
return settings
|