|
@@ -22,7 +22,6 @@ local function includes(s, a)
|
|
|
end
|
|
|
return false
|
|
|
end
|
|
|
-
|
|
|
local function get_bias(mod)
|
|
|
if mod.bias then
|
|
|
return mod.bias:float()
|
|
@@ -31,20 +30,18 @@ local function get_bias(mod)
|
|
|
return torch.FloatTensor(mod.nOutputPlane):zero()
|
|
|
end
|
|
|
end
|
|
|
-local function export(model, output)
|
|
|
+local function export_weight(jmodules, seq)
|
|
|
local targets = {"nn.SpatialConvolutionMM",
|
|
|
"cudnn.SpatialConvolution",
|
|
|
"nn.SpatialFullConvolution",
|
|
|
"cudnn.SpatialFullConvolution"
|
|
|
}
|
|
|
- local jmodules = {}
|
|
|
- local model_config = meta_data(model)
|
|
|
- local first_layer = true
|
|
|
-
|
|
|
- for k = 1, #model.modules do
|
|
|
- local mod = model.modules[k]
|
|
|
+ for k = 1, #seq.modules do
|
|
|
+ local mod = seq.modules[k]
|
|
|
local name = torch.typename(mod)
|
|
|
- if includes(name, targets) then
|
|
|
+ if name == "nn.Sequential" or name == "nn.ConcatTable" then
|
|
|
+ export_weight(jmodules, mod)
|
|
|
+ elseif includes(name, targets) then
|
|
|
local weight = mod.weight:float()
|
|
|
if name:match("FullConvolution") then
|
|
|
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
|
|
@@ -71,6 +68,14 @@ local function export(model, output)
|
|
|
table.insert(jmodules, jmod)
|
|
|
end
|
|
|
end
|
|
|
+end
|
|
|
+local function export(model, output)
|
|
|
+ local jmodules = {}
|
|
|
+ local model_config = meta_data(model)
|
|
|
+ local first_layer = true
|
|
|
+
|
|
|
+ export_weight(jmodules, model)
|
|
|
+
|
|
|
local fp = io.open(output, "w")
|
|
|
if not fp then
|
|
|
error("IO Error: " .. output)
|