-- waifu2x.lua (2.2 KB)

  -- Script preamble: load dependencies and set process-wide defaults.
  -- The line-number gutter that had been fused onto each statement is removed;
  -- with it, none of these lines were valid Lua.
  require 'cudnn'
  require 'sys'
  -- NOTE(review): the global `path` helper used further down is presumably
  -- provided by this Penlight import -- confirm.
  require 'pl'
  require './lib/LeakyReLU'
  local iproc = require './lib/iproc'
  local reconstract = require './lib/reconstract'
  local image_loader = require './lib/image_loader'

  -- Passed as the third argument to every reconstract() call in this script.
  local BLOCK_OFFSET = 7

  -- Make newly created tensors default to torch.FloatTensor.
  torch.setdefaulttensortype('torch.FloatTensor')
  --- Load a serialized model and switch it to evaluation mode.
  -- @param model_dir directory containing the model files
  -- @param filename  model file name inside model_dir
  -- @return the object returned by torch.load, after model:evaluate()
  local function load_model(model_dir, filename)
    local model = torch.load(path.join(model_dir, filename), "ascii")
    model:evaluate()
    return model
  end

  --- Command-line entry point: process one image and save the result.
  -- Parses the options from the global `arg`, loads the input image, runs
  -- the selected method (noise, scale, or noise_scale), saves the result to
  -- the output path, and prints the output path with the elapsed time.
  -- Raises an error when -m is not one of the three supported methods.
  local function waifu2x()
    local cmd = torch.CmdLine()
    cmd:text()
    cmd:text("waifu2x")
    cmd:text("Options:")
    cmd:option("-i", "images/miku_small.png", 'path of input image')
    cmd:option("-o", "(auto)", 'path of output')
    cmd:option("-model_dir", "./models", 'model directory')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2)')
    cmd:option("-crop_size", 128, 'crop size')

    local opt = cmd:parse(arg)

    -- Derive "<input dir>/<input name without extension>(<method>).png"
    -- when no explicit output path was given.
    if opt.o == "(auto)" then
      local name = path.basename(opt.i)
      local e = path.extension(name)
      local base = name:sub(0, name:len() - e:len())
      opt.o = path.join(path.dirname(opt.i),
                        string.format("%s(%s).png", base, opt.m))
    end

    local x = image_loader.load_float(opt.i)
    local new_x = nil
    -- Timing starts after the image is loaded, so it covers model loading
    -- and processing, and also the save below.
    local t = sys.clock()

    if opt.m == "noise" then
      local model = load_model(opt.model_dir,
                               ("noise%d_model.t7"):format(opt.noise_level))
      new_x = reconstract(model, x, BLOCK_OFFSET)
    elseif opt.m == "scale" then
      local model = load_model(opt.model_dir, "scale2.0x_model.t7")
      -- Double the two trailing dimensions before running the scale model.
      x = iproc.scale(x, x:size(3) * 2.0, x:size(2) * 2.0)
      new_x = reconstract(model, x, BLOCK_OFFSET)
    elseif opt.m == "noise_scale" then
      local noise_model = load_model(opt.model_dir,
                                     ("noise%d_model.t7"):format(opt.noise_level))
      local scale_model = load_model(opt.model_dir, "scale2.0x_model.t7")
      -- Run the noise model first, then resize and run the scale model.
      x = reconstract(noise_model, x, BLOCK_OFFSET)
      x = iproc.scale(x, x:size(3) * 2.0, x:size(2) * 2.0)
      new_x = reconstract(scale_model, x, BLOCK_OFFSET)
    else
      -- The option is registered as "-m", so the parsed value lives in opt.m.
      -- The previous code read opt.method (always nil), which made this line
      -- fail with a nil-concatenation error instead of the intended message.
      error("undefined method:" .. opt.m)
    end

    image.save(opt.o, new_x)
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
  end

  waifu2x()