image_loader.lua 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. local gm = require 'graphicsmagick'
  2. local ffi = require 'ffi'
  3. local iproc = require 'iproc'
  4. require 'pl'
  -- Module table; the loader/encoder functions below are attached to it and it
  -- is returned at the end of the file.
  local image_loader = {}

  -- Offset added before flooring during quantization: just under half of one
  -- quantization step for the given number of levels.
  local function half_step(levels)
    return (1.0 / levels) * 0.5 - (1.0e-7 * (1.0 / levels) * 0.5)
  end

  local clip_eps8 = half_step(255.0)    -- used for bit depths below 16
  local clip_eps16 = half_step(65535.0) -- used for bit depths of 16 and above

  -- Value written into the colour planes wherever alpha is fully transparent.
  local background_color = 0.5
  --- Encode an image tensor as a PNG blob.
  -- The input is passed through iproc.byte2float, offset by a small epsilon,
  -- clamped to [0, 1] and quantized to the requested bit depth before being
  -- handed to GraphicsMagick.
  -- @param rgb      image tensor whose first dimension is the plane count:
  --                 1 (intensity), 3 (RGB) or 4 (RGBA), in DHW layout.
  -- @param depth    PNG bit depth (default 8). Depths below 16 are quantized
  --                 to 255 levels, all others to 65535 levels.
  -- @param inplace  when true, the tensor returned by iproc.byte2float is
  --                 modified directly instead of a clone (default false).
  --                 NOTE(review): whether that aliases the caller's tensor
  --                 depends on iproc.byte2float -- confirm.
  -- @return the PNG data as returned by gm.Image:toString(9).
  -- Raises an error when the tensor has an unsupported number of planes.
  function image_loader.encode_png(rgb, depth, inplace)
    depth = depth or 8
    rgb = iproc.byte2float(rgb)

    -- Validate before touching any data so an unsupported input fails with a
    -- clear message (and, with inplace set, without having been modified).
    local colorspace_by_planes = { [1] = "I", [3] = "RGB", [4] = "RGBA" }
    local planes = rgb:size(1)
    local colorspace = colorspace_by_planes[planes]
    if not colorspace then
      error("image_loader.encode_png: unsupported number of channels: " ..
              tostring(planes) .. " (expected 1, 3 or 4)")
    end

    -- Pick the quantization parameters for the requested depth.
    local levels, eps
    if depth < 16 then
      levels, eps = 255, clip_eps8
    else
      levels, eps = 65535, clip_eps16
    end

    -- A nil/false inplace means: work on a private copy.
    if not inplace then
      rgb = rgb:clone()
    end
    rgb:add(eps)
    rgb:clamp(0.0, 1.0)
    rgb:mul(levels):floor():div(levels)

    -- For single-plane input the "I" colorspace is used;
    -- im:colorspace("GRAY") was noted by the original author as not working.
    local im = gm.Image(rgb, colorspace, "DHW")
    return im:depth(depth):format("PNG"):toString(9)
  end
  --- Encode an image tensor as PNG and write it to a file.
  -- @param filename  destination path; opened in binary write mode.
  -- @param rgb       image tensor, forwarded to image_loader.encode_png.
  -- @param depth     bit depth, forwarded to image_loader.encode_png.
  -- @param inplace   in-place flag, forwarded to image_loader.encode_png.
  -- @return true on success.
  -- Raises an error if the file cannot be opened, written or closed.
  function image_loader.save_png(filename, rgb, depth, inplace)
    -- Encode first so that a failing encode does not leave an empty file.
    local blob = image_loader.encode_png(rgb, depth, inplace)
    local fp = io.open(filename, "wb")
    if not fp then
      error("IO error: " .. filename)
    end
    -- file:write and file:close return nil plus a message on failure;
    -- buffered data may only hit the disk at close time, so check both.
    local written, write_err = fp:write(blob)
    local closed, close_err = fp:close()
    if not written then
      error("IO error: " .. filename .. ": " .. tostring(write_err))
    end
    if not closed then
      error("IO error: " .. filename .. ": " .. tostring(close_err))
    end
    return true
  end
  --- Decode an encoded image blob into a float tensor.
  -- @param blob  string holding the encoded image file contents.
  -- @return im     3xHxW FloatTensor (RGB, DHW layout), or nil on failure.
  -- @return alpha  1xHxW tensor with the alpha plane when the image is a
  --                PNG/GIF containing at least one alpha value below 1;
  --                nil otherwise.
  -- @return blob   the input blob, or nil on failure.
  -- Any error raised while decoding is caught and reported as nil, nil, nil.
  function image_loader.decode_float(blob)
    local function decode()
      local lcd_gamma = 0.454545
      local alpha = nil
      local img = gm.Image()
      img:fromBlob(blob, #blob)
      if img:colorspace() == "CMYK" then
        img:colorspace("RGB")
      end

      -- Truncate the reported gamma to 6 decimals so it can be compared with
      -- lcd_gamma; a gamma of 0 is left untouched.
      local gamma = math.floor(img:gamma() * 1000000) / 1000000
      if gamma ~= 0 and gamma ~= lcd_gamma then
        local ratio = gamma / lcd_gamma
        img:gammaCorrection(ratio, "Red")
        img:gammaCorrection(ratio, "Blue")
        img:gammaCorrection(ratio, "Green")
      end

      -- FIXME: How to detect that an image has an alpha channel?
      -- For now the file signature is used: only PNG and GIF are probed.
      -- "\137" is the decimal escape for byte 0x89; unlike "\x89" it is
      -- understood by every Lua version, including PUC Lua 5.1.
      local is_png = blob:sub(1, 4) == "\137PNG"
      local is_gif = blob:sub(1, 3) == "GIF"

      local pixels
      if is_png or is_gif then
        local rgba = img:toTensor('float', 'RGBA', 'DHW')
        local height, width = rgba:size(2), rgba:size(3)
        -- A negative sum of (alpha - 1) means some alpha value is below 1.
        if (rgba[4] - 1.0):sum() < 0 then
          alpha = rgba[4]:reshape(1, height, width)
          -- Replace fully transparent pixels with the background colour.
          local transparent = torch.le(alpha, 0.0)
          for c = 1, 3 do
            rgba[c][transparent] = background_color
          end
        end
        -- Copy the colour planes into a fresh 3-plane tensor.
        pixels = torch.FloatTensor(3, height, width)
        for c = 1, 3 do
          pixels[c]:copy(rgba[c])
        end
      else
        pixels = img:toTensor('float', 'RGB', 'DHW')
      end
      return pixels, alpha, blob
    end

    local ok, pixels, alpha, raw = pcall(decode)
    if ok then
      return pixels, alpha, raw
    end
    return nil, nil, nil
  end
  --- Decode an encoded image blob and convert the pixels with
  -- iproc.float2byte.
  -- @param blob  string holding the encoded image file contents.
  -- @return the converted image, the alpha value returned by decode_float
  --         (passed through unconverted) and the blob; nil, nil, nil when
  --         decode_float yields no image.
  function image_loader.decode_byte(blob)
    local pixels, alpha, raw = image_loader.decode_float(blob)
    if not pixels then
      return nil, nil, nil
    end
    -- Only the colour planes are converted; alpha is handed back as-is.
    local bytes = iproc.float2byte(pixels)
    return bytes, alpha, raw
  end
  --- Read an image file and decode it with image_loader.decode_float.
  -- @param file  path of the image file; opened in binary read mode.
  -- @return whatever image_loader.decode_float returns for the file contents.
  -- Raises an error when the file cannot be opened.
  function image_loader.load_float(file)
    local handle = io.open(file, "rb")
    if handle == nil then
      error(file .. ": failed to load image")
    end
    -- Slurp the entire file, then release the handle before decoding.
    local contents = handle:read("*a")
    handle:close()
    return image_loader.decode_float(contents)
  end
  --- Read an image file and decode it with image_loader.decode_byte.
  -- @param file  path of the image file; opened in binary read mode.
  -- @return whatever image_loader.decode_byte returns for the file contents.
  -- Raises an error when the file cannot be opened.
  function image_loader.load_byte(file)
    local handle = io.open(file, "rb")
    if handle == nil then
      error(file .. ": failed to load image")
    end
    -- Slurp the entire file, then release the handle before decoding.
    local contents = handle:read("*a")
    handle:close()
    return image_loader.decode_byte(contents)
  end
  -- Round-trip self-check: load the sample image, encode it as PNG, decode it
  -- again and require the result to match exactly, once through the float
  -- API and once through the byte API.
  local function test()
    torch.setdefaulttensortype("torch.FloatTensor")
    local sample = "../images/lena.png"

    -- float round trip
    local original = image_loader.load_float(sample)
    local decoded = image_loader.decode_float(image_loader.encode_png(original))
    local float_diff = (decoded - original):abs():sum()
    assert(float_diff == 0)

    -- byte round trip
    original = image_loader.load_byte(sample)
    decoded = image_loader.decode_byte(image_loader.encode_png(original))
    local byte_diff = (decoded:float() - original:float()):abs():sum()
    assert(byte_diff == 0)
  end
  141. --test()
  142. return image_loader