image_loader.lua 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. local gm = require 'graphicsmagick'
  2. require 'pl'
  3. local image_loader = {}
  -- Decode an encoded image `blob` into a float tensor.
  -- The bytes produced by decode_byte are converted to float and divided by
  -- 255, mapping the 0..255 byte range to 0..1.
  -- When decode_byte fails, its (nil) result is returned unchanged.
  function image_loader.decode_float(blob)
     local bytes = image_loader.decode_byte(blob)
     if not bytes then
        return bytes
     end
     return bytes:float():div(255)
  end
  -- Encode `tensor` as a PNG image.
  -- The tensor is wrapped by gm.Image using the "RGB" channel map in "DHW"
  -- (channels-first) layout. Everything toBlob() returns is passed through
  -- unchanged; presumably this is the encoded data and its length (confirm
  -- against the graphicsmagick binding).
  function image_loader.encode_png(tensor)
     local png = gm.Image(tensor, "RGB", "DHW")
     png:format("png")
     return png:toBlob()
  end
  -- Composite a float RGBA tensor onto a white background.
  -- `rgba` is 4 x H x W (channels first), with values expected in [0, 1].
  -- Returns a 3 x H x W ByteTensor.
  local function composite_on_white(rgba)
     local alpha = rgba[4]
     local inv_alpha = alpha * -1 + 1
     local rgb = torch.FloatTensor(3, rgba:size(2), rgba:size(3))
     for c = 1, 3 do
        -- out = color * alpha + white(1) * (1 - alpha)
        rgb[c]:copy(rgba[c]):cmul(alpha):add(inv_alpha)
     end
     -- :byte() truncates toward zero, so 254.9999 would become 254.
     -- Adding 0.5 rounds to the nearest value; clamp keeps the result in
     -- 0..255.
     return rgb:mul(255):add(0.5):clamp(0, 255):byte()
  end

  -- Decode an encoded image `blob` (a Lua string) into a 3 x H x W RGB
  -- ByteTensor.
  --
  -- PNG and GIF blobs, detected by their leading magic bytes, are decoded as
  -- RGBA and composited onto a white background. All other formats are
  -- decoded as RGB only.
  --
  -- Decoding is best-effort. On success it returns the tensor. On any error
  -- it returns nil plus the error message raised while decoding, so existing
  -- `if im then ... end` callers keep working.
  function image_loader.decode_byte(blob)
     local function load_image()
        local im = gm.Image()
        im:fromBlob(blob, #blob)
        -- FIXME: How to detect that a image has an alpha channel?
        -- For now only the PNG and GIF signatures are checked.
        -- "\137PNG" is the PNG signature (0x89 'P' 'N' 'G'). It uses a decimal
        -- escape because Lua 5.1 has no "\x" escape ("\x89" would read as
        -- "x89").
        if blob:sub(1, 4) == "\137PNG" or blob:sub(1, 3) == "GIF" then
           return composite_on_white(im:toTensor('float', 'RGBA', 'DHW'))
        end
        return im:toTensor('byte', 'RGB', 'DHW')
     end
     local ok, result = pcall(load_image)
     if ok then
        return result
     end
     return nil, result
  end
  -- Read the entire contents of `file` in binary mode and return them as a
  -- single string.
  -- Raises "<file>: failed to load image" if the file cannot be opened.
  local function read_image_file(file)
     local fp = io.open(file, "rb")
     if not fp then
        error(file .. ": failed to load image")
     end
     local buff = fp:read("*a")
     fp:close()
     return buff
  end

  -- Load the image at path `file` and decode it with decode_float.
  -- Raises if the file cannot be opened. Otherwise returns whatever
  -- decode_float returns (nil when the data cannot be decoded).
  function image_loader.load_float(file)
     return image_loader.decode_float(read_image_file(file))
  end

  -- Load the image at path `file` and decode it with decode_byte.
  -- Raises if the file cannot be opened. Otherwise returns whatever
  -- decode_byte returns (nil when the data cannot be decoded).
  function image_loader.load_byte(file)
     return image_loader.decode_byte(read_image_file(file))
  end
  -- Manual smoke test. It loads ./a.jpg and ./b.png (relative to the working
  -- directory) as float tensors and shows them with image.display.
  -- For the JPEG it also prints the minimum and maximum pixel values.
  -- Uses the torch 'image' package. It is not run automatically; uncomment
  -- the call below to use it.
  local function test()
     require 'image'
     local jpg = image_loader.load_float("./a.jpg")
     if jpg then
        print(jpg:min())
        print(jpg:max())
        image.display(jpg)
     end
     local png = image_loader.load_float("./b.png")
     if png then
        image.display(png)
     end
  end
  --test()
  return image_loader