image_loader.lua 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. local gm = require 'graphicsmagick'
  2. local ffi = require 'ffi'
  3. require 'pl'
-- Module table; every public function defined below is attached to it and it is returned at the end of the file.
local image_loader = {}
  5. function image_loader.decode_float(blob)
  6. local im, alpha = image_loader.decode_byte(blob)
  7. if im then
  8. im = im:float():div(255)
  9. end
  10. return im, alpha
  11. end
  12. function image_loader.encode_png(rgb, alpha)
  13. if rgb:type() == "torch.ByteTensor" then
  14. error("expect FloatTensor")
  15. end
  16. if alpha then
  17. if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
  18. alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
  19. end
  20. local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
  21. rgba[1]:copy(rgb[1])
  22. rgba[2]:copy(rgb[2])
  23. rgba[3]:copy(rgb[3])
  24. rgba[4]:copy(alpha)
  25. local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
  26. im:format("png")
  27. return im:toBlob()
  28. else
  29. local im = gm.Image(rgb, "RGB", "DHW")
  30. im:format("png")
  31. return im:toBlob()
  32. end
  33. end
  34. function image_loader.save_png(filename, rgb, alpha)
  35. local blob, len = image_loader.encode_png(rgb, alpha)
  36. local fp = io.open(filename, "wb")
  37. fp:write(ffi.string(blob, len))
  38. fp:close()
  39. return true
  40. end
  41. function image_loader.decode_byte(blob)
  42. local load_image = function()
  43. local im = gm.Image()
  44. local alpha = nil
  45. im:fromBlob(blob, #blob)
  46. -- FIXME: How to detect that a image has an alpha channel?
  47. if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
  48. -- split alpha channel
  49. im = im:toTensor('float', 'RGBA', 'DHW')
  50. local sum_alpha = (im[4] - 1):sum()
  51. if sum_alpha > 0 or sum_alpha < 0 then
  52. alpha = im[4]:reshape(1, im:size(2), im:size(3))
  53. end
  54. local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
  55. new_im[1]:copy(im[1])
  56. new_im[2]:copy(im[2])
  57. new_im[3]:copy(im[3])
  58. im = new_im:mul(255):byte()
  59. else
  60. im = im:toTensor('byte', 'RGB', 'DHW')
  61. end
  62. return {im, alpha}
  63. end
  64. local state, ret = pcall(load_image)
  65. if state then
  66. return ret[1], ret[2]
  67. else
  68. return nil
  69. end
  70. end
  71. function image_loader.load_float(file)
  72. local fp = io.open(file, "rb")
  73. if not fp then
  74. error(file .. ": failed to load image")
  75. end
  76. local buff = fp:read("*a")
  77. fp:close()
  78. return image_loader.decode_float(buff)
  79. end
  80. function image_loader.load_byte(file)
  81. local fp = io.open(file, "rb")
  82. if not fp then
  83. error(file .. ": failed to load image")
  84. end
  85. local buff = fp:read("*a")
  86. fp:close()
  87. return image_loader.decode_byte(buff)
  88. end
  89. local function test()
  90. require 'image'
  91. local img
  92. img = image_loader.load_float("./a.jpg")
  93. if img then
  94. print(img:min())
  95. print(img:max())
  96. image.display(img)
  97. end
  98. img = image_loader.load_float("./b.png")
  99. if img then
  100. image.display(img)
  101. end
  102. end
-- Uncomment to run the manual smoke test (needs ./a.jpg, ./b.png and the 'image' package).
--test()
-- Export the module table.
return image_loader