cudnn2cunn.lua 999 B

12345678910111213141516171819202122232425262728293031323334
  1. require 'cunn'
  2. require 'cudnn'
  3. require 'cutorch'
  4. require './lib/LeakyReLU'
  5. local srcnn = require 'lib/srcnn'
  --- Convert a cudnn-backed model into a cunn model built by srcnn.waifu2x("y").
  -- A fresh network is created with srcnn.waifu2x("y"). The weight and bias of
  -- every cudnn.SpatialConvolution found in `cudnn_model` are then copied, in
  -- findModules order, into the nn.SpatialConvolutionMM at the same position
  -- in the fresh network.
  -- @param cudnn_model source network containing cudnn.SpatialConvolution layers
  -- @return the new network, after :cuda() and :evaluate() have been called on it
  -- @raise error if the two networks contain a different number of convolution
  --   layers, or if a paired layer holds a different number of parameters
  local function cudnn2cunn(cudnn_model)
    local cunn_model = srcnn.waifu2x("y")
    local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
    local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")

    -- A count mismatch means the source does not share the architecture that
    -- srcnn.waifu2x("y") builds. Without this check, surplus source layers fail
    -- with an opaque nil-index error. Missing source layers are worse: the
    -- target's trailing layers silently keep whatever weights srcnn.waifu2x
    -- initialized them with, and the partly converted model gets saved.
    if #from_seq ~= #to_seq then
      error(string.format(
        "cudnn2cunn: layer count mismatch (%d cudnn.SpatialConvolution vs %d nn.SpatialConvolutionMM)",
        #from_seq, #to_seq))
    end

    for i, from in ipairs(from_seq) do
      local to = to_seq[i]
      -- Tensor:copy only requires equal element counts, not equal shapes.
      -- NOTE(review): the MM layer may store its weight in a different layout
      -- than the cudnn layer, which is why element counts (not sizes) are compared.
      if from.weight:nElement() ~= to.weight:nElement()
         or from.bias:nElement() ~= to.bias:nElement() then
        error(string.format(
          "cudnn2cunn: parameter size mismatch at convolution layer %d", i))
      end
      to.weight:copy(from.weight)
      to.bias:copy(from.bias)
    end

    cunn_model:cuda()
    cunn_model:evaluate()
    return cunn_model
  end
  -- Command-line entry point: load a cudnn model, convert it with cudnn2cunn,
  -- and save the result.
  -- By default the converted model is written back over -model, which matches
  -- the original behaviour. Pass -output to keep the source file intact.
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text("convert cudnn model to cunn model ")
  cmd:text("Options:")
  cmd:option("-model", "./model.t7", 'path of cudnn model file')
  cmd:option("-output", "", 'path of output cunn model file (default: overwrite -model)')
  cmd:option("-iformat", "ascii", 'input format')
  cmd:option("-oformat", "ascii", 'output format')
  local opt = cmd:parse(arg)

  -- An empty -output means "convert in place", preserving the old default.
  local output_path = opt.output
  if output_path == "" then
    output_path = opt.model
  end

  local cudnn_model = torch.load(opt.model, opt.iformat)
  local cunn_model = cudnn2cunn(cudnn_model)
  torch.save(output_path, cunn_model, opt.oformat)