image_loader.lua 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. local gm = require 'graphicsmagick'
  2. local ffi = require 'ffi'
  3. require 'pl'
  4. local image_loader = {}
  5. local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
  6. local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
  7. local background_color = 0.5
  8. function image_loader.decode_float(blob)
  9. local im, alpha = image_loader.decode_byte(blob)
  10. if im then
  11. im = im:float():div(255)
  12. end
  13. return im, alpha, blob
  14. end
  15. function image_loader.encode_png(rgb, alpha, depth)
  16. depth = depth or 8
  17. if rgb:type() == "torch.ByteTensor" then
  18. rgb = rgb:float():div(255)
  19. end
  20. if alpha then
  21. if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
  22. alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
  23. end
  24. local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
  25. rgba[1]:copy(rgb[1])
  26. rgba[2]:copy(rgb[2])
  27. rgba[3]:copy(rgb[3])
  28. rgba[4]:copy(alpha)
  29. if depth < 16 then
  30. rgba:add(clip_eps8)
  31. rgba[torch.lt(rgba, 0.0)] = 0.0
  32. rgba[torch.gt(rgba, 1.0)] = 1.0
  33. else
  34. rgba:add(clip_eps16)
  35. rgba[torch.lt(rgba, 0.0)] = 0.0
  36. rgba[torch.gt(rgba, 1.0)] = 1.0
  37. end
  38. local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
  39. return im:depth(depth):format("PNG"):toBlob(9)
  40. else
  41. if depth < 16 then
  42. rgb = rgb:clone():add(clip_eps8)
  43. rgb[torch.lt(rgb, 0.0)] = 0.0
  44. rgb[torch.gt(rgb, 1.0)] = 1.0
  45. else
  46. rgb = rgb:clone():add(clip_eps16)
  47. rgb[torch.lt(rgb, 0.0)] = 0.0
  48. rgb[torch.gt(rgb, 1.0)] = 1.0
  49. end
  50. local im = gm.Image(rgb, "RGB", "DHW")
  51. return im:depth(depth):format("PNG"):toBlob(9)
  52. end
  53. end
  54. function image_loader.save_png(filename, rgb, alpha, depth)
  55. depth = depth or 8
  56. local blob, len = image_loader.encode_png(rgb, alpha, depth)
  57. local fp = io.open(filename, "wb")
  58. if not fp then
  59. error("IO error: " .. filename)
  60. end
  61. fp:write(ffi.string(blob, len))
  62. fp:close()
  63. return true
  64. end
  65. function image_loader.decode_byte(blob)
  66. local load_image = function()
  67. local im = gm.Image()
  68. local alpha = nil
  69. local gamma_lcd = 0.454545
  70. im:fromBlob(blob, #blob)
  71. if im:colorspace() == "CMYK" then
  72. im:colorspace("RGB")
  73. end
  74. local gamma = math.floor(im:gamma() * 1000000) / 1000000
  75. if gamma ~= 0 and gamma ~= gamma_lcd then
  76. im:gammaCorrection(gamma / gamma_lcd)
  77. end
  78. -- FIXME: How to detect that a image has an alpha channel?
  79. if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
  80. -- split alpha channel
  81. im = im:toTensor('float', 'RGBA', 'DHW')
  82. local sum_alpha = (im[4] - 1.0):sum()
  83. if sum_alpha < 0 then
  84. alpha = im[4]:reshape(1, im:size(2), im:size(3))
  85. -- drop full transparent background
  86. local mask = torch.le(alpha, 0.0)
  87. im[1][mask] = background_color
  88. im[2][mask] = background_color
  89. im[3][mask] = background_color
  90. end
  91. local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
  92. new_im[1]:copy(im[1])
  93. new_im[2]:copy(im[2])
  94. new_im[3]:copy(im[3])
  95. im = new_im:mul(255):byte()
  96. else
  97. im = im:toTensor('byte', 'RGB', 'DHW')
  98. end
  99. return {im, alpha, blob}
  100. end
  101. local state, ret = pcall(load_image)
  102. if state then
  103. return ret[1], ret[2], ret[3]
  104. else
  105. return nil, nil, nil
  106. end
  107. end
  108. function image_loader.load_float(file)
  109. local fp = io.open(file, "rb")
  110. if not fp then
  111. error(file .. ": failed to load image")
  112. end
  113. local buff = fp:read("*a")
  114. fp:close()
  115. return image_loader.decode_float(buff)
  116. end
  117. function image_loader.load_byte(file)
  118. local fp = io.open(file, "rb")
  119. if not fp then
  120. error(file .. ": failed to load image")
  121. end
  122. local buff = fp:read("*a")
  123. fp:close()
  124. return image_loader.decode_byte(buff)
  125. end
  126. local function test()
  127. require 'image'
  128. local img
  129. img = image_loader.load_float("./a.jpg")
  130. if img then
  131. print(img:min())
  132. print(img:max())
  133. image.display(img)
  134. end
  135. img = image_loader.load_float("./b.png")
  136. if img then
  137. image.display(img)
  138. end
  139. end
  140. --test()
  141. return image_loader