Преглед на файлове

Don't use cudnn.benchmark mode when predicting

nagadomi преди 9 години
родител
ревизия
425898a3aa
променени са 3 файла, в които са добавени 15 реда и са изтрити 1 реда
  1. 5 0
      tools/benchmark.lua
  2. 6 1
      waifu2x.lua
  3. 4 0
      web.lua

+ 5 - 0
tools/benchmark.lua

@@ -27,6 +27,11 @@ cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each time
 
 local opt = cmd:parse(arg)
 torch.setdefaulttensortype('torch.FloatTensor')
+if cudnn then
+   cudnn.fastest = true
+   cudnn.benchmark = false
+end
+
 
 local function MSE(x1, x2)
    return (x1 - x2):pow(2):mean()

+ 6 - 1
waifu2x.lua

@@ -112,11 +112,16 @@ local function waifu2x()
    cmd:option("-crop_size", 128, 'patch size per process')
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")
-
+   
    local opt = cmd:parse(arg)
    if opt.thread > 0 then
       torch.setnumthreads(opt.thread)
    end
+   if cudnn then
+      cudnn.fastest = true
+      cudnn.benchmark = false
+   end
+   
    if string.len(opt.l) == 0 then
       convert_image(opt)
    else

+ 4 - 0
web.lua

@@ -25,6 +25,10 @@ torch.setdefaulttensortype('torch.FloatTensor')
 if opt.thread > 0 then
    torch.setnumthreads(opt.thread)
 end
+if cudnn then
+   cudnn.fastest = true
+   cudnn.benchmark = false
+end
 
 local MODEL_DIR = "./models/anime_style_art_rgb"
 local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")