Przeglądaj źródła

Support Linear module

nagadomi 7 lat temu
rodzic
commit
929c7f85a9
1 zmienionych plików z 17 dodań i 7 usunięć
  1. 17 7
      tools/export_model.lua

+ 17 - 7
tools/export_model.lua

@@ -37,19 +37,29 @@ local function get_bias(mod)
    end
 end
 local function export_weight(jmodules, seq)
-   local targets = {"nn.SpatialConvolutionMM",
-		    "cudnn.SpatialConvolution",
-		    "cudnn.SpatialDilatedConvolution",
-		    "nn.SpatialFullConvolution",
-		    "nn.SpatialDilatedConvolution",
-		    "cudnn.SpatialFullConvolution"
+   local convolutions = {"nn.SpatialConvolutionMM",
+			 "cudnn.SpatialConvolution",
+			 "cudnn.SpatialDilatedConvolution",
+			 "nn.SpatialFullConvolution",
+			 "nn.SpatialDilatedConvolution",
+			 "cudnn.SpatialFullConvolution"
    }
    for k = 1, #seq.modules do
       local mod = seq.modules[k]
       local name = torch.typename(mod)
       if name == "nn.Sequential" or name == "nn.ConcatTable" then
 	 export_weight(jmodules, mod)
-      elseif includes(name, targets) then
+      elseif name == "nn.Linear" then
+	 local weight = torch.totable(mod.weight:float())
+	 local jmod = {
+	    class_name = name,
+	    nInputPlane = mod.weight:size(2),
+	    nOutputPlane = mod.weight:size(1),
+	    bias = torch.totable(get_bias(mod)),
+	    weight = weight
+	 }
+	 table.insert(jmodules, jmod)
+      elseif includes(name, convolutions) then
 	 local weight = mod.weight:float()
 	 if name:match("FullConvolution") then
 	    weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))