Procházet zdrojové kódy

Add to json formt model; Fix json format;

nagadomi před 7 roky
rodič
revize
0b1a13d9c0
1 změnil soubory, kde provedl 18 přidání a 8 odebrání
  1. 18 8
      tools/export_model.lua

+ 18 - 8
tools/export_model.lua

@@ -5,13 +5,19 @@ package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. packa
 require 'w2nn'
 local cjson = require "cjson"
 
-local function meta_data(model)
+local function meta_data(model, model_path)
    local meta = {}
    for k, v in pairs(model) do
       if k:match("w2nn_") then
 	 meta[k:gsub("w2nn_", "")] = v
       end
    end
+
+   modtime = file.modified_time(model_path)
+   utc_date = Date('utc')
+   utc_date:set(modtime)
+   meta["created_at"] = tostring(utc_date)
+
    return meta
 end
 local function includes(s, a)
@@ -33,7 +39,9 @@ end
 local function export_weight(jmodules, seq)
    local targets = {"nn.SpatialConvolutionMM",
 		    "cudnn.SpatialConvolution",
+		    "cudnn.SpatialDilatedConvolution",
 		    "nn.SpatialFullConvolution",
+		    "nn.SpatialDilatedConvolution",
 		    "cudnn.SpatialFullConvolution"
    }
    for k = 1, #seq.modules do
@@ -56,25 +64,27 @@ local function export_weight(jmodules, seq)
 	    dW = mod.dW,
 	    padW = mod.padW,
 	    padH = mod.padH,
+	    dilationW = mod.dilationW,
+	    dilationH = mod.dilationH,
 	    nInputPlane = mod.nInputPlane,
 	    nOutputPlane = mod.nOutputPlane,
 	    bias = torch.totable(get_bias(mod)),
 	    weight = weight
 	 }
-	 if first_layer then
-	    first_layer = false
-	    jmod.model_config = model_config
-	 end
 	 table.insert(jmodules, jmod)
       end
    end
 end
-local function export(model, output)
+local function export(model, model_path, output)
    local jmodules = {}
-   local model_config = meta_data(model)
+   local model_config = meta_data(model, model_path)
    local first_layer = true
 
+   print(model_config)
+   print(model)
+
    export_weight(jmodules, model)
+   jmodules[1]["model_config"] = model_config
 
    local fp = io.open(output, "w")
    if not fp then
@@ -98,4 +108,4 @@ if not path.isfile(opt.i) then
    os.exit(-1)
 end
 local model = torch.load(opt.i, opt.iformat)
-export(model, opt.o)
+export(model, opt.i, opt.o)