-- visualize_layer_output.lua
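-- Saves the input image and the intermediate feature maps produced after each
-- convolution layer of a waifu2x model as PNG files (layer-0.png is the input,
-- layer-N.png is a tiled grid of the N-th layer's feature maps).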

require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'sys'
require 'w2nn'
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
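-- Module type names treated as convolution layers (standard, cuDNN and
-- full/transposed variants) and as activation layers when slicing the model below.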
local CONV_LAYERS = {"nn.SpatialConvolutionMM",
                     "cudnn.SpatialConvolution",
                     "nn.SpatialFullConvolution",
                     "cudnn.SpatialFullConvolution"
}
local ACTIVATION_LAYERS = {"nn.ReLU",
                           "nn.LeakyReLU",
                           "w2nn.LeakyReLU",
                           "cudnn.ReLU",
                           "nn.SoftMax",
                           "cudnn.SoftMax"
}
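-- Returns true if string s is one of the entries of array a
-- (used to match torch.typename(module) against the lists above).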
local function includes(s, a)
   for i = 1, #a do
      if s == a[i] then
         return true
      end
   end
   return false
end
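-- Counts the convolution layers in a nn.Sequential model; an nn.ConcatTable
-- block is counted as a single layer/visualization step.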
local function count_conv_layers(seq)
   local count = 0
   for k = 1, #seq.modules do
      local mod = seq.modules[k]
      local name = torch.typename(mod)
      if name == "nn.ConcatTable" or includes(name, CONV_LAYERS) then
         count = count + 1
      end
   end
   return count
end
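-- Builds a new nn.Sequential containing the modules of seq up to and including
-- the limit-th convolution layer (or ConcatTable). If the module right after it
-- is an activation layer, it is kept as well, so the saved feature maps are
-- post-activation.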
local function strip_conv_layers(seq, limit)
   local new_seq = nn.Sequential()
   local count = 0
   for k = 1, #seq.modules do
      local mod = seq.modules[k]
      local name = torch.typename(mod)
      if name == "nn.ConcatTable" or includes(name, CONV_LAYERS) then
         new_seq:add(mod)
         count = count + 1
         if count == limit then
            if seq.modules[k+1] ~= nil and
               includes(torch.typename(seq.modules[k+1]), ACTIVATION_LAYERS) then
               new_seq:add(seq.modules[k+1])
            end
            return new_seq
         end
      else
         new_seq:add(mod)
      end
   end
   return new_seq
end
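-- Saves x as layer-0.png, then for each convolution layer i builds a truncated
-- copy of the model, runs it on the GPU, drops the batch dimension and tiles the
-- resulting feature maps into a single grid image written to out/layer-i.png.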
local function save_layer_outputs(x, model, out)
   local count = count_conv_layers(model)
   print("conv layer count", count)
   local output_file = path.join(out, string.format("layer-%d.png", 0))
   image.save(output_file, x)
   print("* save layer output " .. 0 .. ": " .. output_file)
   for i = 1, count do
      output_file = path.join(out, string.format("layer-%d.png", i))
      print("* save layer output " .. i .. ": " .. output_file)
      local test_model = strip_conv_layers(model, i)
      test_model:cuda()
      test_model:evaluate()
      local z = test_model:forward(x:reshape(1, x:size(1), x:size(2), x:size(3)):cuda()):float()
      z = z:reshape(z:size(2), z:size(3), z:size(4)) -- drop batch dim
      z = image.toDisplayTensor({input=z, padding=2})
      image.save(output_file, z)
      collectgarbage()
   end
end
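-- Usage sketch (assuming this file lives under tools/ in the waifu2x tree;
-- adjust the path to wherever it is actually installed):
--   th tools/visualize_layer_output.lua -m scale -scale 2 \
--      -model_dir ./models/upconv_7/art -i images/miku_small.png -o ./layer_outputs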
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x - visualize layer output")
cmd:text("Options:")
cmd:option("-i", "images/miku_small.png", 'path to input image')
cmd:option("-scale", 2, 'scale factor')
cmd:option("-o", "./layer_outputs", 'path to output dir')
cmd:option("-model_dir", "./models/upconv_7/art", 'path to model directory')
cmd:option("-name", "user", 'model name for user method')
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale|user)')
cmd:option("-noise_level", 1, '(1|2|3)')
cmd:option("-force_cudnn", 0, 'use cuDNN backend (0|1)')
cmd:option("-gpu", 1, 'Device ID')
local opt = cmd:parse(arg)
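-- Select the GPU device and pre-compute the model path used by the "user" method
-- ({model_dir}/{name}_model.t7); the other methods derive their file names below.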
cutorch.setDevice(opt.gpu)
opt.force_cudnn = opt.force_cudnn == 1
opt.model_path = path.join(opt.model_dir, string.format("%s_model.t7", opt.name))
local x, meta = image_loader.load_float(opt.i)
if x:size(2) > 256 or x:size(3) > 256 then
   error(string.format("input image is too large: %dx%d", x:size(3), x:size(2)))
end
local model = nil
local new_x = nil
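-- Pick the model file according to the method: noise{level}_model.t7,
-- scale{factor}x_model.t7 (e.g. scale2.0x_model.t7),
-- noise{level}_scale{factor}x_model.t7, or {name}_model.t7 for the "user" method.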
if opt.m == "noise" then
   local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
   model = w2nn.load_model(model_path, opt.force_cudnn)
   if not model then
      error("Load Error: " .. model_path)
   end
elseif opt.m == "scale" then
   local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
   model = w2nn.load_model(model_path, opt.force_cudnn)
   if not model then
      error("Load Error: " .. model_path)
   end
elseif opt.m == "noise_scale" then
   local model_path = path.join(opt.model_dir, ("noise%d_scale%.1fx_model.t7"):format(opt.noise_level, opt.scale))
   model = w2nn.load_model(model_path, opt.force_cudnn)
   if not model then
      error("Load Error: " .. model_path)
   end
elseif opt.m == "user" then
   local model_path = opt.model_path
   model = w2nn.load_model(model_path, opt.force_cudnn)
   if not model then
      error("Load Error: " .. model_path)
   end
else
   error("undefined method: " .. opt.m)
end
assert(model ~= nil)
assert(x ~= nil)
dir.makepath(opt.o)
save_layer_outputs(x, model, opt.o)