|
@@ -1,6 +1,7 @@
|
|
|
require 'xlua'
|
|
|
require 'pl'
|
|
|
require 'trepl'
|
|
|
+require 'cutorch'
|
|
|
|
|
|
-- global settings
|
|
|
|
|
@@ -63,6 +64,7 @@ cmd:option("-oracle_drop_rate", 0.5, '')
|
|
|
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
|
|
|
cmd:option("-resume", "", 'resume model file')
|
|
|
cmd:option("-name", "user", 'model name for user method')
|
|
|
+cmd:option("-gpu", 1, 'Device ID')
|
|
|
|
|
|
local function to_bool(settings, name)
|
|
|
if settings[name] == 1 then
|
|
@@ -152,4 +154,6 @@ end
|
|
|
settings.images = string.format("%s/images.t7", settings.data_dir)
|
|
|
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
|
|
|
|
|
+cutorch.setDevice(opt.gpu)
|
|
|
+
|
|
|
return settings
|