|
@@ -5,7 +5,7 @@ package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. packa
|
|
|
require 'w2nn'
|
|
|
local cjson = require "cjson"
|
|
|
|
|
|
-function meta_data(model)
|
|
|
+local function meta_data(model)
|
|
|
local meta = {}
|
|
|
for k, v in pairs(model) do
|
|
|
if k:match("w2nn_") then
|
|
@@ -14,7 +14,7 @@ function meta_data(model)
|
|
|
end
|
|
|
return meta
|
|
|
end
|
|
|
-function includes(s, a)
|
|
|
+local function includes(s, a)
|
|
|
for i = 1, #a do
|
|
|
if s == a[i] then
|
|
|
return true
|
|
@@ -22,7 +22,16 @@ function includes(s, a)
|
|
|
end
|
|
|
return false
|
|
|
end
|
|
|
-function export(model, output)
|
|
|
+
|
|
|
+local function get_bias(mod)
|
|
|
+ if mod.bias then
|
|
|
+ return mod.bias:float()
|
|
|
+ else
|
|
|
+ -- no bias
|
|
|
+ return torch.FloatTensor(mod.nOutputPlane):zero()
|
|
|
+ end
|
|
|
+end
|
|
|
+local function export(model, output)
|
|
|
local targets = {"nn.SpatialConvolutionMM",
|
|
|
"cudnn.SpatialConvolution",
|
|
|
"nn.SpatialFullConvolution",
|
|
@@ -52,7 +61,7 @@ function export(model, output)
|
|
|
padH = mod.padH,
|
|
|
nInputPlane = mod.nInputPlane,
|
|
|
nOutputPlane = mod.nOutputPlane,
|
|
|
- bias = torch.totable(mod.bias:float()),
|
|
|
+ bias = torch.totable(get_bias(mod)),
|
|
|
weight = weight
|
|
|
}
|
|
|
if first_layer then
|