-- switch_aux_output.lua

  require("pl")
  -- Path of this script file. Inside the immediately-invoked function below,
  -- level 1 of debug.getinfo is that function and level 2 is its caller (the
  -- main chunk). A leading "@" in `source` marks a file-based chunk and is
  -- stripped. The extra parentheses keep only gsub's first result.
  local __FILE__ = (function()
    local source = debug.getinfo(2, "S").source
    return (source:gsub("^@", ""))
  end)()
  -- Make modules in ../lib (relative to this script) loadable via require.
  package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
  require("os")
  require("w2nn")
  local srcnn = require("srcnn")
  --- Depth-first search for the first w2nn.AuxiliaryLossTable in a module tree.
  -- The node itself is checked first; then, if it exposes a `modules` table
  -- (as nn.Sequential and nn.ConcatTable do, and presumably other nn
  -- containers too), each child is searched in order.
  -- @param seq root module to search
  -- @return the first AuxiliaryLossTable found, or nil if there is none
  local function find_aux(seq)
    if torch.typename(seq) == "w2nn.AuxiliaryLossTable" then
      return seq
    end
    -- Leaf modules have no children. Guard here instead of indexing
    -- `#seq.modules` unconditionally, which raised an error on leaves.
    local children = seq.modules
    if type(children) ~= "table" then
      return nil
    end
    for _, child in ipairs(children) do
      local aux = find_aux(child)
      if aux ~= nil then
        return aux
      end
    end
    return nil
  end
  -- Command-line interface. `opt.i`/`opt.o` are the input/output model files,
  -- `opt.iformat`/`opt.oformat` their torch serialization formats, and `opt.j`
  -- the output path index to store in the AuxiliaryLossTable.
  local cmd = torch.CmdLine()
  cmd:text()
  -- Banner text: "path" (not "pass"), matching the -j option description.
  cmd:text("switch the output path of auxiliary loss")
  cmd:text("Options:")
  cmd:option("-j", 1, 'Specify the output path index (1|2)')
  cmd:option("-i", "", 'Specify the input model')
  cmd:option("-o", "", 'Specify the output model')
  cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
  cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
  local opt = cmd:parse(arg)
  -- Validate all arguments before doing any model I/O.
  if not path.isfile(opt.i) then
    cmd:help()
    os.exit(-1)
  end
  -- -o defaults to ""; without this check the failure would only surface
  -- inside torch.save, after the (possibly large) model has been loaded.
  if opt.o == "" then
    cmd:help()
    os.exit(-1)
  end
  -- -j is stored as the AuxiliaryLossTable's output index. Per its help text
  -- "(1|2)" it is presumably a 1-based index, so reject anything that is not
  -- a positive integer.
  local j = tonumber(opt.j)
  if j == nil or j < 1 or j ~= math.floor(j) then
    cmd:help()
    os.exit(-1)
  end

  local model = torch.load(opt.i, opt.iformat)
  if model == nil then
    print("load error")
    os.exit(-1)
  end
  local aux = find_aux(model)
  if aux == nil then
    print("AuxiliaryLossTable not found")
    -- Exit non-zero like the other failure paths, so calling scripts do not
    -- assume an output model was written.
    os.exit(-1)
  end
  print(aux)
  aux.i = j
  torch.save(opt.o, model, opt.oformat)