-- adapted from https://github.com/marcan/cl-waifu2x
 
- require 'pl'
 
- local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
 
- package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
 
- require 'w2nn'
 
- local cjson = require "cjson"
 
--- Collect the waifu2x-specific attributes stored on a model.
-- Every string key of `model` that starts with "w2nn_" is copied into the
-- result with that prefix stripped; all other entries are ignored.
-- @param model table to scan (the loaded model object)
-- @return a new table mapping each un-prefixed attribute name to its value
function meta_data(model)
   local prefix = "w2nn_"
   local meta = {}
   for k, v in pairs(model) do
      -- Non-string keys have no string methods, so they are skipped rather
      -- than raising. The comparison is a literal prefix test: the marker
      -- must be at the start of the key, not merely somewhere inside it.
      if type(k) == "string" and k:sub(1, #prefix) == prefix then
         meta[k:sub(#prefix + 1)] = v
      end
   end
   return meta
end
 
--- Test whether a value occurs in an array.
-- @param s value to look for (compared with `==`)
-- @param a array-like table searched from index 1 up to `#a`
-- @return true if some element equals `s`, false otherwise
function includes(s, a)
   local count = #a
   local idx = 1
   while idx <= count do
      if a[idx] == s then
         return true
      end
      idx = idx + 1
   end
   return false
end
 
--- Build the JSON-ready description of a single convolution module.
-- The weight tensor is converted to float and reshaped to 4 dimensions:
-- (nInputPlane, nOutputPlane, kH, kW) when the class name contains
-- "FullConvolution", (nOutputPlane, nInputPlane, kH, kW) otherwise.
-- @param layer the convolution module to describe
-- @param class_name the module's torch type name
-- @return table holding the layer geometry, bias and weight as plain tables
local function describe_conv(layer, class_name)
   local kernel = layer.weight:float()
   if class_name:match("FullConvolution") then
      kernel = kernel:reshape(layer.nInputPlane, layer.nOutputPlane, layer.kH, layer.kW)
   else
      kernel = kernel:reshape(layer.nOutputPlane, layer.nInputPlane, layer.kH, layer.kW)
   end
   kernel = torch.totable(kernel)
   return {
      class_name = class_name,
      kW = layer.kW,
      kH = layer.kH,
      dH = layer.dH,
      dW = layer.dW,
      padW = layer.padW,
      padH = layer.padH,
      nInputPlane = layer.nInputPlane,
      nOutputPlane = layer.nOutputPlane,
      bias = torch.totable(layer.bias:float()),
      weight = kernel
   }
end

--- Write the convolution layers of a model to a JSON file.
-- Only the entries of `model.modules` whose type is one of the listed
-- convolution classes are exported, in order. The model's "w2nn_" metadata
-- (see meta_data) is attached as `model_config` to the first exported layer.
-- @param model model object exposing a `modules` array
-- @param output path of the JSON file to create
-- Raises an error if the output file cannot be opened for writing.
function export(model, output)
   local conv_classes = {
      "nn.SpatialConvolutionMM",
      "cudnn.SpatialConvolution",
      "nn.SpatialFullConvolution",
      "cudnn.SpatialFullConvolution"
   }
   local layers = {}
   local config = meta_data(model)
   local config_pending = true

   for idx = 1, #model.modules do
      local layer = model.modules[idx]
      local class_name = torch.typename(layer)
      if includes(class_name, conv_classes) then
         local entry = describe_conv(layer, class_name)
         -- The metadata rides along on the first exported layer only.
         if config_pending then
            config_pending = false
            entry.model_config = config
         end
         layers[#layers + 1] = entry
      end
   end

   local fp = io.open(output, "w")
   if not fp then
      error("IO Error: " .. output)
   end
   fp:write(cjson.encode(layers))
   fp:close()
end
 
-- Command-line entry point: declare the options, parse the arguments, then
-- load the model and export it.
local cli = torch.CmdLine()
cli:text()
cli:text("waifu2x export model")
cli:text("Options:")
cli:option("-i", "input.t7", 'Specify the input torch model')
cli:option("-o", "output.json", 'Specify the output json file')
cli:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')

local options = cli:parse(arg)

if path.isfile(options.i) then
   export(torch.load(options.i, options.iformat), options.o)
else
   -- The input model does not exist: show the usage text and exit.
   cli:help()
   os.exit(-1)
end
 
 
  |