소스 검색

Make the performance benchmark practical

nagadomi 6 년 전
부모
커밋
17b8de2d36
1개의 변경된 파일19개의 추가작업 그리고 5개의 파일을 삭제
  1. 19 5
      lib/srcnn.lua

+ 19 - 5
lib/srcnn.lua

@@ -890,18 +890,31 @@ function srcnn.upcunet_v2(backend, ch)
 end
 local function bench()
    local sys = require 'sys'
-   cudnn.benchmark = false
+   cudnn.benchmark = true
    local model = nil
    local arch = {"upconv_7", "upcunet", "upcunet_v2"}
-   local backend = "cunn"
+   local backend = "cudnn"
    for k = 1, #arch do
       model = srcnn[arch[k]](backend, 3):cuda()
-      model:training()
+      model:evaluate()
+      local dummy = nil
+      -- warn
+      for i = 1, 20 do
+	 local x = torch.Tensor(4, 3, 172, 172):uniform():cuda()
+	 model:forward(x)
+      end
       t = sys.clock()
-      for i = 1, 10 do
-	 model:forward(torch.Tensor(1, 3, 172, 172):zero():cuda())
+      for i = 1, 20 do
+	 local x = torch.Tensor(4, 3, 172, 172):uniform():cuda()
+	 local z = model:forward(x)
+	 if dummy == nil then
+	    dummy = z:clone()
+	 else
+	    dummy:add(z)
+	 end
       end
       print(arch[k], sys.clock() - t)
+      model:clearState()
    end
 end
 function srcnn.create(model_name, backend, color)
@@ -935,4 +948,5 @@ model:training()
 print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
 os.exit()
 --]]
+
 return srcnn