-- export_model.lua (852 B)

  -- Export the convolution layers of a Torch model as JSON on stdout.
  -- adapted from https://github.com/marcan/cl-waifu2x
  --
  -- Usage: th export_model.lua MODEL_FILE [FORMAT] > model.json
  --   MODEL_FILE  model file readable by torch.load
  --   FORMAT      torch.load format argument (default: "ascii")
  --
  -- Output: a JSON array with one object per convolution layer, in the order
  -- model:findModules returns them. Each object has kW, kH, nInputPlane,
  -- nOutputPlane, bias (flat list) and weight (nested list shaped
  -- nOutputPlane x nInputPlane x kH x kW).

  -- penlight: defines the global `path` module used below. Without it,
  -- `path.join` / `path.dirname` index a nil global.
  require 'pl'
  local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  -- Put ../lib (relative to this script's directory) on the module search
  -- path so that `require 'w2nn'` can find it.
  package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
  require 'w2nn'
  local cjson = require "cjson"

  -- Print a message to stderr and exit with a non-zero status.
  local function die(msg)
    io.stderr:write(msg, "\n")
    os.exit(1)
  end

  local model_file = arg[1]
  if model_file == nil then
    die(string.format("usage: %s MODEL_FILE [FORMAT]", arg[0] or "export_model.lua"))
  end
  -- Defaults to "ascii", the format this script always used.
  local model_format = arg[2] or "ascii"
  local model = torch.load(model_file, model_format)

  -- Convolution layer types, in order of preference. The first type with at
  -- least one match is exported. SpatialConvolutionMM comes first, so models
  -- built from it export exactly as before; the other types are fallbacks.
  local CONV_TYPES = {
    "nn.SpatialConvolutionMM",
    "nn.SpatialConvolution",
    "cudnn.SpatialConvolution",
  }

  local modules = {}
  for _, type_name in ipairs(CONV_TYPES) do
    modules = model:findModules(type_name)
    if #modules > 0 then
      break
    end
  end
  -- lua-cjson encodes an empty table as the JSON object {} rather than the
  -- array [], so fail loudly instead of emitting a malformed model.
  if #modules == 0 then
    die(string.format("%s: no convolution layers found (looked for %s)",
                      model_file, table.concat(CONV_TYPES, ", ")))
  end

  local jmodules = {}
  for i, conv in ipairs(modules) do
    jmodules[i] = {
      kW = conv.kW,
      kH = conv.kH,
      nInputPlane = conv.nInputPlane,
      nOutputPlane = conv.nOutputPlane,
      bias = torch.totable(conv.bias:float()),
      -- Torch lays out convolution weights as out x in x kH x kW (the MM
      -- variant stores the last three flattened). Rows (kH) come before
      -- columns (kW). Reshaping as (..., kW, kH) is only correct for square
      -- kernels.
      weight = torch.totable(conv.weight:float():reshape(
        conv.nOutputPlane, conv.nInputPlane, conv.kH, conv.kW)),
    }
  end
  io.write(cjson.encode(jmodules))