image_loader.lua 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. local gm = require 'graphicsmagick'
  2. local ffi = require 'ffi'
  3. local iproc = require 'iproc'
  4. require 'pl'
  -- Module table; it is returned at the end of this file.
  local image_loader = {}

  -- Bias added before quantization in encode_png: slightly less than half of
  -- one quantization step (1 / maxval), so that floor(x * maxval) after the
  -- bias approximately rounds to the nearest level.
  local function half_level_bias(maxval)
    local level = 1.0 / maxval
    return level * 0.5 - (1.0e-7 * level * 0.5)
  end
  local clip_eps8 = half_level_bias(255.0)    -- 8-bit output
  local clip_eps16 = half_level_bias(65535.0) -- 16-bit output

  -- Value written into the colour channels of fully transparent pixels when
  -- decoding images that carry an alpha channel.
  local background_color = 0.5
  -- Encode an image tensor as PNG data via GraphicsMagick.
  --
  -- rgb     : channel-first ("DHW") tensor with 1 (Y), 3 (RGB) or 4 (RGBA)
  --           channels. It is passed through iproc.byte2float first
  --           (presumably so byte input is accepted too - TODO confirm).
  -- options : optional table; it is read but never modified.
  --   depth     : output bit depth. Values < 16 quantize to 8 bits, anything
  --               else to 16 bits (default 8).
  --   inplace   : when truthy, quantize the tensor returned by
  --               iproc.byte2float directly instead of a clone (default false).
  --   grayscale : convert 3/4-channel input to a grayscale (matte) image.
  --   gamma     : if set, passed to im:gamma() before encoding.
  -- Returns whatever im:toString(9) returns for the PNG-formatted image.
  -- Raises an error when rgb has a channel count other than 1, 3 or 4.
  function image_loader.encode_png(rgb, options)
    options = options or {}
    -- Keep option defaults local rather than writing them into the caller's table.
    local depth = options.depth or 8
    local inplace = options.inplace or false

    local maxval, eps
    if depth < 16 then
      maxval, eps = 255, clip_eps8
    else
      maxval, eps = 65535, clip_eps16
    end

    rgb = iproc.byte2float(rgb)
    if not inplace then
      rgb = rgb:clone()
    end
    -- Bias by just under half a step, clip to [0, 1], then snap every value
    -- to a representable level of the target depth (all ops are in place).
    rgb:add(eps)
    rgb:clamp(0.0, 1.0)
    rgb:mul(maxval):floor():div(maxval)

    local channels = rgb:size(1)
    local im
    if channels == 4 then -- RGBA
      im = gm.Image(rgb, "RGBA", "DHW")
      if options.grayscale then
        im:type("GrayscaleMatte")
      end
    elseif channels == 3 then -- RGB
      im = gm.Image(rgb, "RGB", "DHW")
      if options.grayscale then
        im:type("Grayscale")
      end
    elseif channels == 1 then -- Y
      im = gm.Image(rgb, "I", "DHW")
      im:type("Grayscale")
    else
      error("encode_png: unsupported number of channels: " .. tostring(channels))
    end
    if options.gamma then
      im:gamma(options.gamma)
    end
    return im:depth(depth):format("PNG"):toString(9)
  end
  -- Encode rgb as PNG with image_loader.encode_png(rgb, options) and write the
  -- result to filename.
  -- Returns true on success. Raises "IO error: <filename>" if the file cannot
  -- be opened, written or closed.
  function image_loader.save_png(filename, rgb, options)
    -- Encode before opening so an encoding failure does not create or
    -- truncate the output file.
    local blob = image_loader.encode_png(rgb, options)
    local fp = io.open(filename, "wb")
    if not fp then
      error("IO error: " .. filename)
    end
    -- file:write / file:close return nil on failure instead of raising, so
    -- check both results rather than silently reporting success.
    local written = fp:write(blob)
    local closed = fp:close()
    if not written or not closed then
      error("IO error: " .. filename)
    end
    return true
  end
  -- Decode an encoded image blob with GraphicsMagick into a float tensor.
  --
  -- Returns (im, meta) on success:
  --   im   : 3 x H x W float tensor in RGB order ("DHW" layout).
  --   meta : table with
  --     blob      : the input blob
  --     gamma     : the image gamma, only when it is non-zero and differs from
  --                 0.454545 (compared after truncation to 6 decimals)
  --     grayscale : true when the image type is Grayscale or GrayscaleMatte
  --     alpha     : 1 x H x W alpha channel, only for TrueColorMatte /
  --                 GrayscaleMatte images
  -- Returns (nil, nil) if any decoding step raises an error (best effort).
  function image_loader.decode_float(blob)
    local function load_image()
      local meta = {}
      local im = gm.Image()
      local gamma_lcd = 0.454545
      im:fromBlob(blob, #blob)
      if im:colorspace() == "CMYK" then
        im:colorspace("RGB")
      end
      -- Use the image's own gamma for the zero check (a zero gamma is
      -- presumably "unset" in GraphicsMagick - TODO confirm) so it is not
      -- recorded and later re-applied by encode_png.
      local gamma = im:gamma()
      if gamma ~= 0 and math.floor(gamma * 1000000) / 1000000 ~= gamma_lcd then
        meta.gamma = gamma
      end
      local image_type = im:type()
      if image_type == "Grayscale" or image_type == "GrayscaleMatte" then
        meta.grayscale = true
      end
      local rgb
      if image_type == "TrueColorMatte" or image_type == "GrayscaleMatte" then
        -- Split off the alpha channel.
        local rgba = im:toTensor('float', 'RGBA', 'DHW')
        meta.alpha = rgba[4]:reshape(1, rgba:size(2), rgba:size(3))
        -- Replace the colour of fully transparent pixels (alpha <= 0) with
        -- background_color; the mask is H x W to match each colour channel.
        local mask = torch.le(meta.alpha[1], 0.0)
        for c = 1, 3 do
          rgba[c][mask] = background_color
        end
        -- Keep only the colour channels, as a new contiguous tensor.
        rgb = rgba:narrow(1, 1, 3):clone()
      else
        rgb = im:toTensor('float', 'RGB', 'DHW')
      end
      meta.blob = blob
      return rgb, meta
    end
    local ok, im, meta = pcall(load_image)
    if not ok then
      return nil, nil
    end
    return im, meta
  end
  -- Decode an encoded image blob into a byte tensor.
  -- The image is decoded with image_loader.decode_float and then converted
  -- with iproc.float2byte. Returns (im, meta), or (nil, nil) when decoding fails.
  -- NOTE: meta.alpha, when present, is returned unconverted (still float).
  function image_loader.decode_byte(blob)
    local float_im, meta = image_loader.decode_float(blob)
    if not float_im then
      return nil, nil
    end
    return iproc.float2byte(float_im), meta
  end
  -- Read the whole of `file` (binary mode) into a string.
  -- Raises "<file>: failed to load image" if it cannot be opened or read.
  local function read_image_file(file)
    local fp = io.open(file, "rb")
    if not fp then
      error(file .. ": failed to load image")
    end
    local buff = fp:read("*a")
    fp:close()
    if not buff then
      error(file .. ": failed to load image")
    end
    return buff
  end

  -- Load an image file as a float tensor.
  -- Returns the same (im, meta) pair as image_loader.decode_float, i.e.
  -- (nil, nil) when the file contents cannot be decoded.
  function image_loader.load_float(file)
    return image_loader.decode_float(read_image_file(file))
  end

  -- Load an image file as a byte tensor.
  -- Returns the same (im, meta) pair as image_loader.decode_byte, i.e.
  -- (nil, nil) when the file contents cannot be decoded.
  function image_loader.load_byte(file)
    return image_loader.decode_byte(read_image_file(file))
  end
  -- Round-trip self-test. It loads ../images/lena.png, re-encodes it as PNG,
  -- decodes the result and requires an exact match with the loaded image.
  -- This runs once for the float path and once for the byte path. The
  -- relative path assumes a suitable working directory.
  local function test()
    torch.setdefaulttensortype("torch.FloatTensor")
    local path = "../images/lena.png"

    local float_src = image_loader.load_float(path)
    local float_dst = image_loader.decode_float(image_loader.encode_png(float_src))
    assert((float_dst - float_src):abs():sum() == 0)

    local byte_src = image_loader.load_byte(path)
    local byte_dst = image_loader.decode_byte(image_loader.encode_png(byte_src))
    assert((byte_dst:float() - byte_src:float()):abs():sum() == 0)
  end
  -- Uncomment the next line to run the self-test when this module is loaded.
  --test()
  return image_loader