Bladeren bron

Add -gpu option

nagadomi 8 jaren geleden
bovenliggende
commit
50fd999c38
2 gewijzigde bestanden met toevoegingen van 6 en 0 verwijderingen
  1. 4 0
      lib/settings.lua
  2. 2 0
      waifu2x.lua

+ 4 - 0
lib/settings.lua

@@ -1,6 +1,7 @@
 require 'xlua'
 require 'pl'
 require 'trepl'
+require 'cutorch'
 
 -- global settings
 
@@ -63,6 +64,7 @@ cmd:option("-oracle_drop_rate", 0.5, '')
 cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
 cmd:option("-resume", "", 'resume model file')
 cmd:option("-name", "user", 'model name for user method')
+cmd:option("-gpu", 1, 'Device ID')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then
@@ -152,4 +154,6 @@ end
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
 
+cutorch.setDevice(opt.gpu)
+
 return settings

+ 2 - 0
waifu2x.lua

@@ -267,6 +267,7 @@ local function waifu2x()
    cmd:option("-tta_level", 8, 'TTA level (2|4|8). A higher value makes better quality output but slow')
    cmd:option("-force_cudnn", 0, 'use cuDNN backend (0|1)')
    cmd:option("-q", 0, 'quiet (0|1)')
+   cmd:option("-gpu", 1, 'Device ID')
 
    local opt = cmd:parse(arg)
    if opt.method:len() > 0 then
@@ -292,5 +293,6 @@ local function waifu2x()
    else
       convert_frames(opt)
    end
+   cutorch.setDevice(opt.gpu)
 end
 waifu2x()