|
@@ -5,13 +5,19 @@ package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. packa
|
|
|
require 'w2nn'
|
|
|
local cjson = require "cjson"
|
|
|
|
|
|
-local function meta_data(model)
|
|
|
+local function meta_data(model, model_path)
|
|
|
local meta = {}
|
|
|
for k, v in pairs(model) do
|
|
|
if k:match("w2nn_") then
|
|
|
meta[k:gsub("w2nn_", "")] = v
|
|
|
end
|
|
|
end
|
|
|
+
|
|
|
+ modtime = file.modified_time(model_path)
|
|
|
+ utc_date = Date('utc')
|
|
|
+ utc_date:set(modtime)
|
|
|
+ meta["created_at"] = tostring(utc_date)
|
|
|
+
|
|
|
return meta
|
|
|
end
|
|
|
local function includes(s, a)
|
|
@@ -33,7 +39,9 @@ end
|
|
|
local function export_weight(jmodules, seq)
|
|
|
local targets = {"nn.SpatialConvolutionMM",
|
|
|
"cudnn.SpatialConvolution",
|
|
|
+ "cudnn.SpatialDilatedConvolution",
|
|
|
"nn.SpatialFullConvolution",
|
|
|
+ "nn.SpatialDilatedConvolution",
|
|
|
"cudnn.SpatialFullConvolution"
|
|
|
}
|
|
|
for k = 1, #seq.modules do
|
|
@@ -56,25 +64,27 @@ local function export_weight(jmodules, seq)
|
|
|
dW = mod.dW,
|
|
|
padW = mod.padW,
|
|
|
padH = mod.padH,
|
|
|
+ dilationW = mod.dilationW,
|
|
|
+ dilationH = mod.dilationH,
|
|
|
nInputPlane = mod.nInputPlane,
|
|
|
nOutputPlane = mod.nOutputPlane,
|
|
|
bias = torch.totable(get_bias(mod)),
|
|
|
weight = weight
|
|
|
}
|
|
|
- if first_layer then
|
|
|
- first_layer = false
|
|
|
- jmod.model_config = model_config
|
|
|
- end
|
|
|
table.insert(jmodules, jmod)
|
|
|
end
|
|
|
end
|
|
|
end
|
|
|
-local function export(model, output)
|
|
|
+local function export(model, model_path, output)
|
|
|
local jmodules = {}
|
|
|
- local model_config = meta_data(model)
|
|
|
+ local model_config = meta_data(model, model_path)
|
|
|
local first_layer = true
|
|
|
|
|
|
+ print(model_config)
|
|
|
+ print(model)
|
|
|
+
|
|
|
export_weight(jmodules, model)
|
|
|
+ jmodules[1]["model_config"] = model_config
|
|
|
|
|
|
local fp = io.open(output, "w")
|
|
|
if not fp then
|
|
@@ -98,4 +108,4 @@ if not path.isfile(opt.i) then
|
|
|
os.exit(-1)
|
|
|
end
|
|
|
local model = torch.load(opt.i, opt.iformat)
|
|
|
-export(model, opt.o)
|
|
|
+export(model, opt.i, opt.o)
|