-- export_model.lua
  1. -- adapted from https://github.com/marcan/cl-waifu2x
  2. local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
  3. package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
  4. require 'w2nn'
  5. local cjson = require "cjson"
  6. function export(model, output)
  7. local jmodules = {}
  8. local modules = model:findModules("nn.SpatialConvolutionMM")
  9. if #modules == 0 then
  10. -- cudnn model
  11. modules = model:findModules("cudnn.SpatialConvolution")
  12. end
  13. for i = 1, #modules, 1 do
  14. local module = modules[i]
  15. local jmod = {
  16. kW = module.kW,
  17. kH = module.kH,
  18. nInputPlane = module.nInputPlane,
  19. nOutputPlane = module.nOutputPlane,
  20. bias = torch.totable(module.bias:float()),
  21. weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
  22. }
  23. table.insert(jmodules, jmod)
  24. end
  25. local fp = io.open(output, "w")
  26. if not fp then
  27. error("IO Error: " .. output)
  28. end
  29. fp:write(cjson.encode(jmodules))
  30. fp:close()
  31. end
  32. local cmd = torch.CmdLine()
  33. cmd:text()
  34. cmd:text("waifu2x export model")
  35. cmd:text("Options:")
  36. cmd:option("-i", "input.t7", 'Specify the input torch model')
  37. cmd:option("-o", "output.json", 'Specify the output json file')
  38. cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
  39. local opt = cmd:parse(arg)
  40. if not path.isfile(opt.i) then
  41. cmd:help()
  42. os.exit(-1)
  43. end
  44. local model = torch.load(opt.i, opt.iformat)
  45. export(model, opt.o)