image_loader.lua 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. local gm = require 'graphicsmagick'
  2. local ffi = require 'ffi'
  3. local iproc = require 'iproc'
  4. require 'pl'
  local image_loader = {}
  -- Offset used by encode_png before clipping to [0, 1]: half of one
  -- quantization step for a channel with `levels` steps, reduced by a
  -- relative 1e-7 margin.
  local function half_step_offset(levels)
    return (1.0 / levels) * 0.5 - (1.0e-7 * (1.0 / levels) * 0.5)
  end
  local clip_eps8 = half_step_offset(255.0)    -- used when depth < 16
  local clip_eps16 = half_step_offset(65535.0) -- used when depth >= 16
  -- RGB value decode_float writes into pixels whose alpha is <= 0
  local background_color = 0.5
  -- Add the depth-specific offset to `t` in place, then clip it to [0, 1].
  -- Presumably the offset makes the encoder's quantization round to nearest
  -- rather than truncate -- NOTE(review): confirm against GraphicsMagick.
  local function clip_for_depth(t, depth)
    t:add(depth < 16 and clip_eps8 or clip_eps16)
    t:clamp(0.0, 1.0)
    return t
  end

  -- Encode an image as a PNG blob via GraphicsMagick (compression level 9).
  --   rgb:   image tensor indexed as channel x height x width (channels 1..3
  --          are used); passed through iproc.byte2float so byte input is
  --          accepted.
  --   alpha: optional alpha plane, also passed through iproc.byte2float; it is
  --          indexed as D x H x W (presumably 1xHxW, as decode_float returns)
  --          and resized to rgb's H x W with the SincFast filter when the
  --          sizes differ.
  --   depth: bits per channel written to the PNG (default 8).
  -- Returns the result of toBlob(9); callers in this file use it as (blob, len).
  -- The caller's tensors are never modified.
  function image_loader.encode_png(rgb, alpha, depth)
    depth = depth or 8
    rgb = iproc.byte2float(rgb)
    if not alpha then
      -- clone so clipping does not touch the caller's tensor
      local clipped = clip_for_depth(rgb:clone(), depth)
      return gm.Image(clipped, "RGB", "DHW"):depth(depth):format("PNG"):toBlob(9)
    end

    -- normalise alpha the same way as rgb so byte alpha planes are scaled too
    alpha = iproc.byte2float(alpha)
    if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
      alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
    end
    local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
    for c = 1, 3 do
      rgba[c]:copy(rgb[c])
    end
    rgba[4]:copy(alpha)
    clip_for_depth(rgba, depth)
    local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
    return im:depth(depth):format("PNG"):toBlob(9)
  end
  -- Encode rgb/alpha with image_loader.encode_png and write the PNG bytes to
  -- `filename` (binary mode).
  -- Raises "IO error: <filename>" if the file cannot be opened, or if writing
  -- or closing it fails. Returns true on success.
  function image_loader.save_png(filename, rgb, alpha, depth)
    depth = depth or 8
    local blob, len = image_loader.encode_png(rgb, alpha, depth)
    -- build the byte string before opening, so a failure here leaves no
    -- empty file and no open handle behind
    local data = ffi.string(blob, len)
    local fp = io.open(filename, "wb")
    if not fp then
      error("IO error: " .. filename)
    end
    -- file:write and file:close report failure through their return values
    -- instead of raising, so check them rather than silently returning true
    local written = fp:write(data)
    local closed = fp:close()
    if not (written and closed) then
      error("IO error: " .. filename)
    end
    return true
  end
  -- Decode an encoded image `blob` (a Lua string) with GraphicsMagick.
  -- Returns:
  --   rgb   - 3xHxW FloatTensor
  --   alpha - 1xHxW alpha plane, or nil when the blob is not PNG/GIF (by magic
  --           bytes) or no pixel has alpha below 1
  --   blob  - the input blob
  -- Any error raised while decoding is swallowed and reported as nil, nil, nil.
  function image_loader.decode_float(blob)
    -- Split a 4xHxW RGBA float tensor into a fresh 3xHxW FloatTensor plus an
    -- optional 1xHxW alpha. When alpha is kept, RGB of pixels with alpha <= 0
    -- is overwritten with background_color (in the RGBA tensor, before the copy).
    local function split_alpha(rgba)
      local h, w = rgba:size(2), rgba:size(3)
      local alpha = nil
      if (rgba[4] - 1.0):sum() < 0 then
        alpha = rgba[4]:reshape(1, h, w)
        local transparent = torch.le(alpha, 0.0)
        for c = 1, 3 do
          rgba[c][transparent] = background_color
        end
      end
      local rgb = torch.FloatTensor(3, h, w)
      for c = 1, 3 do
        rgb[c]:copy(rgba[c])
      end
      return rgb, alpha
    end

    local function decode()
      local gamma_lcd = 0.454545
      local img = gm.Image()
      img:fromBlob(blob, #blob)
      if img:colorspace() == "CMYK" then
        img:colorspace("RGB")
      end
      -- compare the stored gamma at 6-decimal precision
      local gamma = math.floor(img:gamma() * 1000000) / 1000000
      if gamma ~= 0 and gamma ~= gamma_lcd then
        img:gammaCorrection(gamma / gamma_lcd)
      end
      -- FIXME: alpha is only looked for in PNG and GIF, detected by magic
      -- bytes; there is no general check for an alpha channel here.
      local may_have_alpha = blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF"
      if not may_have_alpha then
        return img:toTensor('float', 'RGB', 'DHW'), nil, blob
      end
      local rgb, alpha = split_alpha(img:toTensor('float', 'RGBA', 'DHW'))
      return rgb, alpha, blob
    end

    local ok, rgb, alpha, raw = pcall(decode)
    if not ok then
      return nil, nil, nil
    end
    return rgb, alpha, raw
  end
  -- Like image_loader.decode_float, but the RGB image is converted with
  -- iproc.float2byte. The alpha plane is returned exactly as decode_float
  -- produced it (it is NOT converted to bytes).
  -- Returns nil, nil, nil when decoding fails.
  function image_loader.decode_byte(blob)
    local img, alpha, raw = image_loader.decode_float(blob)
    if not img then
      return nil, nil, nil
    end
    return iproc.float2byte(img), alpha, raw
  end
  -- Read the entire contents of `file` in binary mode and return them as a string.
  -- Raises "<file>: failed to load image" when the file cannot be opened.
  -- Shared by load_float and load_byte, which previously duplicated this code.
  local function read_image_file(file)
    local fp = io.open(file, "rb")
    if not fp then
      error(file .. ": failed to load image")
    end
    local buff = fp:read("*a")
    fp:close()
    return buff
  end

  -- Load `file` and decode it with image_loader.decode_float
  -- (returns rgb, alpha, blob, or nil, nil, nil on a decode failure).
  function image_loader.load_float(file)
    return image_loader.decode_float(read_image_file(file))
  end

  -- Load `file` and decode it with image_loader.decode_byte
  -- (returns rgb, alpha, blob, or nil, nil, nil on a decode failure).
  function image_loader.load_byte(file)
    return image_loader.decode_byte(read_image_file(file))
  end
  -- Round-trip self-test: load ../images/lena.png through both the float and
  -- the byte paths, re-encode it as PNG, decode it again, and assert that the
  -- pixels are unchanged.
  local function test()
    torch.setdefaulttensortype("torch.FloatTensor")
    local source = "../images/lena.png"
    local paths = {
      {image_loader.load_float, image_loader.decode_float},
      {image_loader.load_byte, image_loader.decode_byte},
    }
    for _, path in ipairs(paths) do
      local loader, decoder = path[1], path[2]
      local original = loader(source)
      local png, png_len = image_loader.encode_png(original)
      local decoded = decoder(ffi.string(png, png_len))
      -- :float() returns the tensor itself when it is already a FloatTensor
      assert((decoded:float() - original:float()):abs():sum() == 0)
    end
  end
  --test()
  return image_loader