-- cleanup_model.lua (741 B)

  -- Command-line utility: loads a serialized model, passes it through
  -- w2nn.cleanup_model, and writes the result back to the same path.
  --
  -- Options:
  --   -model    path of the model file (read, then overwritten in place)
  --   -iformat  format argument handed to torch.load
  --   -oformat  format argument handed to torch.save
  require 'pl'

  -- Path of this script: the chunk source with a leading "@" stripped.
  -- The extra parentheses keep only the first value returned by gsub.
  local script_source = debug.getinfo(1, 'S').source
  local script_file = (string.gsub(script_source, "^@", ""))

  -- Make the sibling "../lib" directory searchable before requiring w2nn.
  local lib_pattern = path.join(path.dirname(script_file), "..", "lib", "?.lua;")
  package.path = lib_pattern .. package.path
  require 'w2nn'

  torch.setdefaulttensortype("torch.FloatTensor")

  -- Command-line interface.
  local parser = torch.CmdLine()
  parser:text()
  parser:text("cleanup model")
  parser:text("Options:")
  parser:option("-model", "./model.t7", 'path of model file')
  parser:option("-iformat", "binary", 'input format')
  parser:option("-oformat", "binary", 'output format')
  local options = parser:parse(arg)

  local net = torch.load(options.model, options.iformat)
  if not net then
     error("model not found")
  end

  -- NOTE(review): w2nn.cleanup_model is defined outside this file; presumably
  -- it strips state not needed for inference -- confirm against lib/w2nn.
  w2nn.cleanup_model(net)
  net:cuda()
  net:evaluate()

  -- The output path is the input path, so the original file is replaced.
  torch.save(options.model, net, options.oformat)