image_loader.lua 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. local gm = require 'graphicsmagick'
  2. local ffi = require 'ffi'
  3. local iproc = require 'iproc'
  4. local sRGB2014 = require 'sRGB2014'
  5. require 'pl'
  6. local image_loader = {}
  7. local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
  8. local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
  9. local background_color = 0.5
  10. function image_loader.encode_png(rgb, options)
  11. options = options or {}
  12. options.depth = options.depth or 8
  13. if options.inplace == nil then
  14. options.inplace = false
  15. end
  16. rgb = iproc.byte2float(rgb)
  17. if options.depth < 16 then
  18. if options.inplace then
  19. rgb:add(clip_eps8)
  20. else
  21. rgb = rgb:clone():add(clip_eps8)
  22. end
  23. rgb[torch.lt(rgb, 0.0)] = 0.0
  24. rgb[torch.gt(rgb, 1.0)] = 1.0
  25. rgb = rgb:mul(255):floor():div(255)
  26. else
  27. if options.inplace then
  28. rgb:add(clip_eps16)
  29. else
  30. rgb = rgb:clone():add(clip_eps16)
  31. end
  32. rgb[torch.lt(rgb, 0.0)] = 0.0
  33. rgb[torch.gt(rgb, 1.0)] = 1.0
  34. rgb = rgb:mul(65535):floor():div(65535)
  35. end
  36. local im
  37. if rgb:size(1) == 4 then -- RGBA
  38. im = gm.Image(rgb, "RGBA", "DHW")
  39. if options.grayscale then
  40. im:type("GrayscaleMatte")
  41. end
  42. elseif rgb:size(1) == 3 then -- RGB
  43. im = gm.Image(rgb, "RGB", "DHW")
  44. if options.grayscale then
  45. im:type("Grayscale")
  46. end
  47. elseif rgb:size(1) == 1 then -- Y
  48. im = gm.Image(rgb, "I", "DHW")
  49. im:type("Grayscale")
  50. end
  51. if options.gamma then
  52. im:gamma(options.gamma)
  53. end
  54. if options.icm and im.profile then
  55. im:profile("icm", sRGB2014)
  56. im:profile("icm", options.icm)
  57. end
  58. return im:depth(options.depth):format("PNG"):toString()
  59. end
  60. function image_loader.save_png(filename, rgb, options)
  61. local blob = image_loader.encode_png(rgb, options)
  62. local fp = io.open(filename, "wb")
  63. if not fp then
  64. error("IO error: " .. filename)
  65. end
  66. fp:write(blob)
  67. fp:close()
  68. return true
  69. end
  70. function image_loader.decode_float(blob)
  71. local load_image = function()
  72. local meta = {}
  73. local im = gm.Image()
  74. local gamma_lcd = 0.454545
  75. im:fromBlob(blob, #blob)
  76. if im.profile then
  77. meta.icm = im:profile("icm")
  78. if meta.icm then
  79. im:profile("icm", sRGB2014)
  80. im:removeProfile()
  81. end
  82. end
  83. if im:colorspace() == "CMYK" then
  84. im:colorspace("RGB")
  85. end
  86. if gamma ~= 0 and math.floor(im:gamma() * 1000000) / 1000000 ~= gamma_lcd then
  87. meta.gamma = im:gamma()
  88. end
  89. local image_type = im:type()
  90. if image_type == "Grayscale" or image_type == "GrayscaleMatte" then
  91. meta.grayscale = true
  92. end
  93. if image_type == "TrueColorMatte" or image_type == "GrayscaleMatte" then
  94. -- split alpha channel
  95. im = im:toTensor('float', 'RGBA', 'DHW')
  96. meta.alpha = im[4]:reshape(1, im:size(2), im:size(3))
  97. -- drop full transparent background
  98. local mask = torch.le(meta.alpha, 0.0)
  99. im[1][mask] = background_color
  100. im[2][mask] = background_color
  101. im[3][mask] = background_color
  102. local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
  103. new_im[1]:copy(im[1])
  104. new_im[2]:copy(im[2])
  105. new_im[3]:copy(im[3])
  106. im = new_im
  107. else
  108. im = im:toTensor('float', 'RGB', 'DHW')
  109. end
  110. meta.blob = blob
  111. return {im, meta}
  112. end
  113. local state, ret = pcall(load_image)
  114. if state then
  115. return ret[1], ret[2]
  116. else
  117. return nil, nil
  118. end
  119. end
  120. function image_loader.decode_byte(blob)
  121. local im, meta
  122. im, meta = image_loader.decode_float(blob)
  123. if im then
  124. im = iproc.float2byte(im)
  125. -- hmm, alpha does not convert here
  126. return im, meta
  127. else
  128. return nil, nil
  129. end
  130. end
  131. function image_loader.load_float(file)
  132. local fp = io.open(file, "rb")
  133. if not fp then
  134. error(file .. ": failed to load image")
  135. end
  136. local buff = fp:read("*a")
  137. fp:close()
  138. return image_loader.decode_float(buff)
  139. end
  140. function image_loader.load_byte(file)
  141. local fp = io.open(file, "rb")
  142. if not fp then
  143. error(file .. ": failed to load image")
  144. end
  145. local buff = fp:read("*a")
  146. fp:close()
  147. return image_loader.decode_byte(buff)
  148. end
  149. local function test()
  150. torch.setdefaulttensortype("torch.FloatTensor")
  151. local a = image_loader.load_float("../images/lena.png")
  152. local blob = image_loader.encode_png(a)
  153. local b = image_loader.decode_float(blob)
  154. assert((b - a):abs():sum() == 0)
  155. a = image_loader.load_byte("../images/lena.png")
  156. blob = image_loader.encode_png(a)
  157. b = image_loader.decode_byte(blob)
  158. assert((b:float() - a:float()):abs():sum() == 0)
  159. end
  160. --test()
  161. return image_loader