|
@@ -890,18 +890,31 @@ function srcnn.upcunet_v2(backend, ch)
|
|
|
end
|
|
|
local function bench()
|
|
|
local sys = require 'sys'
|
|
|
- cudnn.benchmark = false
|
|
|
+ cudnn.benchmark = true
|
|
|
local model = nil
|
|
|
local arch = {"upconv_7", "upcunet", "upcunet_v2"}
|
|
|
- local backend = "cunn"
|
|
|
+ local backend = "cudnn"
|
|
|
for k = 1, #arch do
|
|
|
model = srcnn[arch[k]](backend, 3):cuda()
|
|
|
- model:training()
|
|
|
+ model:evaluate()
|
|
|
+ local dummy = nil
|
|
|
+ -- warn
|
|
|
+ for i = 1, 20 do
|
|
|
+ local x = torch.Tensor(4, 3, 172, 172):uniform():cuda()
|
|
|
+ model:forward(x)
|
|
|
+ end
|
|
|
t = sys.clock()
|
|
|
- for i = 1, 10 do
|
|
|
- model:forward(torch.Tensor(1, 3, 172, 172):zero():cuda())
|
|
|
+ for i = 1, 20 do
|
|
|
+ local x = torch.Tensor(4, 3, 172, 172):uniform():cuda()
|
|
|
+ local z = model:forward(x)
|
|
|
+ if dummy == nil then
|
|
|
+ dummy = z:clone()
|
|
|
+ else
|
|
|
+ dummy:add(z)
|
|
|
+ end
|
|
|
end
|
|
|
print(arch[k], sys.clock() - t)
|
|
|
+ model:clearState()
|
|
|
end
|
|
|
end
|
|
|
function srcnn.create(model_name, backend, color)
|
|
@@ -935,4 +948,5 @@ model:training()
|
|
|
print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
|
|
|
os.exit()
|
|
|
--]]
|
|
|
+
|
|
|
return srcnn
|