-- export_model.lua
  1. -- adapted from https://github.com/marcan/cl-waifu2x
  2. require './lib/portable'
  3. require './lib/LeakyReLU'
  4. local cjson = require "cjson"
  5. local model = torch.load(arg[1], "ascii")
  6. local jmodules = {}
  7. local modules = model:findModules("nn.SpatialConvolutionMM")
  8. for i = 1, #modules, 1 do
  9. local module = modules[i]
  10. local jmod = {
  11. kW = module.kW,
  12. kH = module.kH,
  13. nInputPlane = module.nInputPlane,
  14. nOutputPlane = module.nOutputPlane,
  15. bias = torch.totable(module.bias:float()),
  16. weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
  17. }
  18. table.insert(jmodules, jmod)
  19. end
  20. io.write(cjson.encode(jmodules))