-- image_loader.lua (1.8 KB)
  1. local gm = require 'graphicsmagick'
  2. require 'pl'
  3. local image_loader = {}
  -- Decode an encoded image blob and rescale the result to floats.
  --
  -- blob: string holding the encoded image data; it is handed unchanged to
  --       image_loader.decode_byte.
  -- Returns the decoded tensor converted with :float() and divided by 255,
  -- or whatever falsy value decode_byte produced when decoding failed.
  function image_loader.decode_float(blob)
      local bytes = image_loader.decode_byte(blob)
      if not bytes then
          -- Decoding failed; pass the failure value straight through.
          return bytes
      end
      local scaled = bytes:float():div(255)
      return scaled
  end
  -- Encode a tensor as PNG.
  --
  -- tensor: image data handed to gm.Image with the "RGB" colorspace and
  --         "DHW" dimension-order arguments.
  -- Returns whatever gm.Image:toBlob() returns for the PNG-formatted image.
  function image_loader.encode_png(tensor)
      local png = gm.Image(tensor, "RGB", "DHW")
      png:format("png")
      -- Tail call so that every value returned by toBlob() is forwarded.
      return png:toBlob()
  end
  -- Decode an encoded image blob into a byte tensor.
  --
  -- blob: string holding the encoded image data.
  -- Returns a byte tensor produced with the 'RGB'/'DHW' layout arguments, or
  -- nil when anything inside the decoder raises an error (errors are caught
  -- with pcall on purpose: decoding is best-effort).
  --
  -- Blobs that start with a PNG or GIF signature are read as RGBA and
  -- composited over a white background, because the alpha channel cannot be
  -- detected directly (see FIXME below).
  function image_loader.decode_byte(blob)
      local function load_image()
          local im = gm.Image()
          im:fromBlob(blob, #blob)
          -- FIXME: How to detect that an image has an alpha channel?
          -- "\137" is the decimal form of byte 0x89; decimal escapes work on
          -- every Lua version, whereas "\x89" is not understood by Lua 5.1.
          if blob:sub(1, 4) == "\137PNG" or blob:sub(1, 3) == "GIF" then
              -- merge alpha channel
              im = im:toTensor('float', 'RGBA', 'DHW')
              local w2 = im[4]           -- weight of the image colour
              local w1 = im[4] * -1 + 1  -- weight of the background (1 - alpha)
              local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
              -- apply the white background: colour * alpha + 1 * (1 - alpha)
              new_im[1]:copy(im[1]):cmul(w2):add(w1)
              new_im[2]:copy(im[2]):cmul(w2):add(w1)
              new_im[3]:copy(im[3]):cmul(w2):add(w1)
              -- The byte conversion truncates, so add 0.5 to round to the
              -- nearest integer; otherwise float round-off (e.g. 254.99998)
              -- would darken a channel by one level.
              im = new_im:mul(255):add(0.5):byte()
          else
              im = im:toTensor('byte', 'RGB', 'DHW')
          end
          return im
      end

      local ok, result = pcall(load_image)
      if ok then
          return result
      end
      return nil
  end
  -- Read an image file and decode it with image_loader.decode_float.
  --
  -- file: path of the image file; it is opened in binary mode.
  -- Returns the value produced by image_loader.decode_float for the file's
  -- contents. When the file cannot be opened, returns nil plus the error
  -- message from io.open instead of raising an "index a nil value" error.
  function image_loader.load_float(file)
      local fp, err = io.open(file, "rb")
      if not fp then
          return nil, err
      end
      local buff = fp:read("*a")
      fp:close()
      return image_loader.decode_float(buff)
  end
  -- Read an image file and decode it with image_loader.decode_byte.
  --
  -- file: path of the image file; it is opened in binary mode.
  -- Returns the value produced by image_loader.decode_byte for the file's
  -- contents. When the file cannot be opened, returns nil plus the error
  -- message from io.open instead of raising an "index a nil value" error.
  function image_loader.load_byte(file)
      local fp, err = io.open(file, "rb")
      if not fp then
          return nil, err
      end
      local buff = fp:read("*a")
      fp:close()
      return image_loader.decode_byte(buff)
  end
  -- Manual smoke test: loads two sample images through load_float and shows
  -- them with the 'image' package. For the JPEG it also prints the minimum
  -- and maximum values of the loaded tensor.
  local function test()
      require 'image'

      local jpg = image_loader.load_float("./a.jpg")
      if jpg then
          print(jpg:min())
          print(jpg:max())
          image.display(jpg)
      end

      local png = image_loader.load_float("./b.png")
      if png then
          image.display(png)
      end
  end
  70. --test()
  71. return image_loader