|
@@ -37,19 +37,29 @@ local function get_bias(mod)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
local function export_weight(jmodules, seq)
|
|
local function export_weight(jmodules, seq)
|
|
- local targets = {"nn.SpatialConvolutionMM",
|
|
|
|
- "cudnn.SpatialConvolution",
|
|
|
|
- "cudnn.SpatialDilatedConvolution",
|
|
|
|
- "nn.SpatialFullConvolution",
|
|
|
|
- "nn.SpatialDilatedConvolution",
|
|
|
|
- "cudnn.SpatialFullConvolution"
|
|
|
|
|
|
+ local convolutions = {"nn.SpatialConvolutionMM",
|
|
|
|
+ "cudnn.SpatialConvolution",
|
|
|
|
+ "cudnn.SpatialDilatedConvolution",
|
|
|
|
+ "nn.SpatialFullConvolution",
|
|
|
|
+ "nn.SpatialDilatedConvolution",
|
|
|
|
+ "cudnn.SpatialFullConvolution"
|
|
}
|
|
}
|
|
for k = 1, #seq.modules do
|
|
for k = 1, #seq.modules do
|
|
local mod = seq.modules[k]
|
|
local mod = seq.modules[k]
|
|
local name = torch.typename(mod)
|
|
local name = torch.typename(mod)
|
|
if name == "nn.Sequential" or name == "nn.ConcatTable" then
|
|
if name == "nn.Sequential" or name == "nn.ConcatTable" then
|
|
export_weight(jmodules, mod)
|
|
export_weight(jmodules, mod)
|
|
- elseif includes(name, targets) then
|
|
|
|
|
|
+ elseif name == "nn.Linear" then
|
|
|
|
+ local weight = torch.totable(mod.weight:float())
|
|
|
|
+ local jmod = {
|
|
|
|
+ class_name = name,
|
|
|
|
+ nInputPlane = mod.weight:size(2),
|
|
|
|
+ nOutputPlane = mod.weight:size(1),
|
|
|
|
+ bias = torch.totable(get_bias(mod)),
|
|
|
|
+ weight = weight
|
|
|
|
+ }
|
|
|
|
+ table.insert(jmodules, jmod)
|
|
|
|
+ elseif includes(name, convolutions) then
|
|
local weight = mod.weight:float()
|
|
local weight = mod.weight:float()
|
|
if name:match("FullConvolution") then
|
|
if name:match("FullConvolution") then
|
|
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
|
|
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
|