image_loader.lua 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. local gm = require 'graphicsmagick'
  2. local ffi = require 'ffi'
  3. require 'pl'
  -- Module table; the image encode/decode/load/save helpers below are attached to it.
  local image_loader = {}
  --- Decode an encoded image blob into a float RGB tensor scaled to [0, 1].
  -- Decoding itself is delegated to image_loader.decode_byte; its byte
  -- result is converted to float and divided by 255.
  -- @param blob encoded image bytes (Lua string)
  -- @return rgb float tensor, or nil when decode_byte failed
  -- @return alpha exactly as returned by decode_byte (nil when absent)
  -- @return blob, the input passed through unchanged
  function image_loader.decode_float(blob)
    local rgb, alpha = image_loader.decode_byte(blob)
    if not rgb then
      -- Decoding failed: propagate the empty result unchanged.
      return rgb, alpha, blob
    end
    return rgb:float():div(255), alpha, blob
  end
  --- Encode an RGB(A) image tensor as a PNG blob using graphicsmagick.
  -- @param rgb 3xHxW tensor (channels indexed by the first dimension).
  --   A torch.ByteTensor is rescaled from 0..255 to [0, 1]; any other type
  --   is passed through as-is (presumably already in [0, 1] — TODO confirm).
  -- @param alpha optional alpha plane, indexed as alpha:size(2)/size(3), so it
  --   is assumed to be 1xHxW (TODO confirm). A torch.ByteTensor is rescaled
  --   the same way as rgb. If its spatial size differs from rgb, it is resized
  --   to rgb's size with the "SincFast" filter.
  -- @param depth PNG bit depth passed to Image:depth, default 8
  -- @return the values of gm Image:toBlob (save_png consumes them as blob, length)
  function image_loader.encode_png(rgb, alpha, depth)
    depth = depth or 8
    if rgb:type() == "torch.ByteTensor" then
      rgb = rgb:float():div(255)
    end
    local im
    if alpha then
      -- Scale a byte alpha plane like rgb, so that all four channels of the
      -- RGBA buffer built below share the same [0, 1] scale.
      if alpha:type() == "torch.ByteTensor" then
        alpha = alpha:float():div(255)
      end
      if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
        -- Image:size takes (width, height): rgb:size(3) is W, rgb:size(2) is H.
        alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
      end
      local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
      for c = 1, 3 do
        rgba[c]:copy(rgb[c])
      end
      rgba[4]:copy(alpha)
      im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
    else
      im = gm.Image(rgb, "RGB", "DHW")
    end
    im:format("png")
    -- The 9 is forwarded to toBlob unchanged (presumably a quality/compression setting).
    return im:depth(depth):toBlob(9)
  end
  --- Encode an image as PNG with image_loader.encode_png and write it to a file.
  -- @param filename destination path (opened in binary mode, truncated)
  -- @param rgb forwarded to image_loader.encode_png
  -- @param alpha optional, forwarded to image_loader.encode_png
  -- @param depth PNG bit depth, default 8, forwarded to image_loader.encode_png
  -- @return true on success
  -- @raise "IO error: <filename>" if the file cannot be opened, written or closed
  function image_loader.save_png(filename, rgb, alpha, depth)
    depth = depth or 8
    local blob, len = image_loader.encode_png(rgb, alpha, depth)
    -- Materialise the bytes before opening the file, so a failure here
    -- cannot leave an open handle behind.
    local data = ffi.string(blob, len)
    local fp = io.open(filename, "wb")
    if not fp then
      error("IO error: " .. filename)
    end
    -- file:write and file:close return nil on failure (close can surface a
    -- deferred flush error); treat either as an IO error rather than
    -- reporting success.
    local written = fp:write(data)
    local closed = fp:close()
    if not (written and closed) then
      error("IO error: " .. filename)
    end
    return true
  end
  --- Decode an encoded image blob into an RGB byte tensor plus an optional alpha plane.
  -- CMYK images are converted to RGB. If the image gamma, truncated to 6
  -- decimals, is neither 0 nor 0.454545, gammaCorrection(gamma / 0.454545)
  -- is applied. Only PNG and GIF input (detected by magic bytes) is probed for
  -- alpha. The alpha plane is returned only when some pixel has alpha below 1.
  -- Any error raised while decoding is caught, and the function returns nils.
  -- @param blob encoded image bytes (Lua string)
  -- @return rgb 3xHxW torch.ByteTensor, or nil on failure
  -- @return alpha 1xHxW float tensor (presumably in [0, 1]), or nil
  -- @return blob, the input unchanged, or nil on failure
  function image_loader.decode_byte(blob)
    local load_image = function()
      local im = gm.Image()
      local alpha = nil
      local gamma_lcd = 0.454545
      im:fromBlob(blob, #blob)
      if im:colorspace() == "CMYK" then
        im:colorspace("RGB")
      end
      local gamma = math.floor(im:gamma() * 1000000) / 1000000
      if gamma ~= 0 and gamma ~= gamma_lcd then
        im:gammaCorrection(gamma / gamma_lcd)
      end
      -- FIXME: How to detect that an image has an alpha channel?
      -- For now only PNG and GIF (by magic bytes) are probed for alpha.
      if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
        -- split alpha channel
        local rgba = im:toTensor('float', 'RGBA', 'DHW')
        -- A negative sum of (alpha - 1) means at least one pixel is not fully
        -- opaque (assuming alpha never exceeds 1).
        if (rgba[4] - 1.0):sum() < 0 then
          alpha = rgba[4]:reshape(1, rgba:size(2), rgba:size(3))
        end
        local rgb = torch.FloatTensor(3, rgba:size(2), rgba:size(3))
        rgb:copy(rgba:narrow(1, 1, 3))
        -- Round instead of truncating: v/255 * 255 computed in float can land
        -- just below v, and a plain byte() cast would turn that into v - 1.
        im = rgb:mul(255):add(0.5):clamp(0, 255):byte()
      else
        im = im:toTensor('byte', 'RGB', 'DHW')
      end
      return im, alpha, blob
    end
    -- Decode exactly once, under pcall, so malformed input yields nils
    -- instead of propagating an error.
    local ok, im, alpha, decoded_blob = pcall(load_image)
    if ok then
      return im, alpha, decoded_blob
    end
    return nil, nil, nil
  end
  --- Read the whole content of an image file in binary mode.
  -- @param file path to read
  -- @return the file content as a Lua string
  -- @raise "<file>: failed to load image" if the file cannot be opened or read
  local function read_image_file(file)
    local fp = io.open(file, "rb")
    if not fp then
      error(file .. ": failed to load image")
    end
    local buff = fp:read("*a")
    fp:close()
    -- fp:read returns nil on failure (e.g. when the path is a directory);
    -- report it the same way as a failed open instead of decoding nil.
    if not buff then
      error(file .. ": failed to load image")
    end
    return buff
  end

  --- Load an image file and decode it with image_loader.decode_float.
  -- @param file path to the image
  -- @return the results of image_loader.decode_float (rgb, alpha, blob)
  -- @raise "<file>: failed to load image" if the file cannot be opened or read
  function image_loader.load_float(file)
    return image_loader.decode_float(read_image_file(file))
  end

  --- Load an image file and decode it with image_loader.decode_byte.
  -- @param file path to the image
  -- @return the results of image_loader.decode_byte (rgb, alpha, blob)
  -- @raise "<file>: failed to load image" if the file cannot be opened or read
  function image_loader.load_byte(file)
    return image_loader.decode_byte(read_image_file(file))
  end
  --- Manual smoke test, not run on load.
  -- Loads ./a.jpg and ./b.png with image_loader.load_float and shows each
  -- successfully decoded image with image.display. For a.jpg it first prints
  -- the minimum and maximum pixel values. Uncomment the call below to run it.
  local function test()
    require 'image'
    local function show(path, print_range)
      local img = image_loader.load_float(path)
      if not img then
        return
      end
      if print_range then
        print(img:min())
        print(img:max())
      end
      image.display(img)
    end
    show("./a.jpg", true)
    show("./b.png", false)
  end
  --test()
  return image_loader