-- cudnn2cunn.lua

local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'os'
require 'pl'
require 'torch'
require 'cutorch'
require 'cunn'
require 'cudnn'
require 'w2nn'
local srcnn = require 'srcnn'
-- Build a cunn model with the same architecture as the cudnn model and copy
-- the convolution weights and biases over, layer by layer.
local function cudnn2cunn(cudnn_model)
   local cunn_model = srcnn.waifu2x_cunn(srcnn.channels(cudnn_model))
   local weight_from = cudnn_model:findModules("cudnn.SpatialConvolution")
   local weight_to = cunn_model:findModules("nn.SpatialConvolutionMM")
   assert(#weight_from == #weight_to)
   for i = 1, #weight_from do
      local from = weight_from[i]
      local to = weight_to[i]
      to.weight:copy(from.weight)
      to.bias:copy(from.bias)
   end
   cunn_model:cuda()
   cunn_model:evaluate()
   return cunn_model
end
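-- Optional sanity check (a sketch, not part of the original tool): given a
-- loaded cudnn_model and the cunn_model returned above, run both on the same
-- random input and compare outputs. The input size and tolerance here are
-- illustrative assumptions.
--
--   local x = torch.CudaTensor(1, srcnn.channels(cudnn_model), 32, 32):uniform()
--   local diff = (cudnn_model:forward(x) - cunn_model:forward(x)):abs():max()
--   assert(diff < 1e-4)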
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x cudnn model to cunn model converter")
cmd:text("Options:")
cmd:option("-i", "", 'Specify the input cudnn model')
cmd:option("-o", "", 'Specify the output cunn model')
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
local opt = cmd:parse(arg)
if not path.isfile(opt.i) then
   cmd:help()
   os.exit(-1)
end
local cudnn_model = torch.load(opt.i, opt.iformat)
local cunn_model = cudnn2cunn(cudnn_model)
torch.save(opt.o, cunn_model, opt.oformat)
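
-- Example invocation (the model file names are hypothetical; adjust paths to
-- your checkout):
--
--   th cudnn2cunn.lua -i scale2.0x_model.cudnn.t7 -o scale2.0x_model.cunn.t7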