local gm = require 'graphicsmagick'
local ffi = require 'ffi'
require 'pl'

-- Image encode/decode helpers built on graphicsmagick.
-- NOTE(review): `torch` is referenced as a global and is not required in this
-- file; presumably it is loaded by the caller or by graphicsmagick -- confirm.
local image_loader = {}

-- Reads the whole content of `file` in binary mode and returns it as a string.
-- Raises an error ("FILE: failed to load image") when the file cannot be
-- opened.
local function read_file(file)
   local fp = io.open(file, "rb")
   if not fp then
      -- Level 2 attributes the error to the load_* function that called us.
      error(file .. ": failed to load image", 2)
   end
   local buff = fp:read("*a")
   fp:close()
   return buff
end

-- Decodes an encoded image blob (string) into a float tensor.
-- The byte tensor produced by decode_byte is converted to float and divided
-- by 255; the alpha value is passed through unchanged.
-- Returns: im, alpha  (nil when decode_byte fails).
function image_loader.decode_float(blob)
   local im, alpha = image_loader.decode_byte(blob)
   if im then
      im = im:float():div(255)
   end
   return im, alpha
end

-- Encodes `rgb` (and optionally `alpha`) as PNG.
--   rgb:   non-Byte tensor indexed as [channel][y][x] ("DHW") with 3 channels.
--   alpha: optional tensor with a single channel; when its height/width differ
--          from `rgb` it is resized to match using the "Sinc" filter.
-- Raises "expect FloatTensor" when `rgb` is a torch.ByteTensor.
-- Returns whatever gm.Image:toBlob() returns (used by save_png as blob, len).
function image_loader.encode_png(rgb, alpha)
   if rgb:type() == "torch.ByteTensor" then
      error("expect FloatTensor")
   end
   local im
   if alpha then
      if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
         -- gm.Image:size takes (width, height), i.e. dims 3 and 2 of a DHW tensor.
         alpha = gm.Image(alpha, "I", "DHW")
            :size(rgb:size(3), rgb:size(2), "Sinc")
            :toTensor("float", "I", "DHW")
      end
      local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
      rgba[1]:copy(rgb[1])
      rgba[2]:copy(rgb[2])
      rgba[3]:copy(rgb[3])
      rgba[4]:copy(alpha)
      im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
   else
      im = gm.Image(rgb, "RGB", "DHW")
   end
   im:format("png")
   return im:toBlob()
end

-- Encodes `rgb`/`alpha` as PNG (see encode_png) and writes it to `filename`.
-- Raises an error when the file cannot be opened or written.
-- Returns true on success.
function image_loader.save_png(filename, rgb, alpha)
   local blob, len = image_loader.encode_png(rgb, alpha)
   local fp, open_err = io.open(filename, "wb")
   if not fp then
      error(filename .. ": failed to save image (" .. tostring(open_err) .. ")")
   end
   -- file:write returns a true value on success and nil plus a message on
   -- failure; close the handle before reporting so it is never leaked.
   local written, write_err = fp:write(ffi.string(blob, len))
   fp:close()
   if not written then
      error(filename .. ": failed to save image (" .. tostring(write_err) .. ")")
   end
   return true
end

-- Decodes an encoded image blob (string) into a byte tensor.
-- Returns: im, alpha
--   im:    3-channel byte tensor in "DHW" layout.
--   alpha: for PNG/GIF input whose alpha channel is not uniformly 1, a float
--          tensor of size 1 x H x W holding that channel; otherwise nil.
-- Any error raised while decoding is swallowed and a single nil is returned.
function image_loader.decode_byte(blob)
   local load_image = function()
      local im = gm.Image()
      local alpha = nil

      im:fromBlob(blob, #blob)
      -- FIXME: How to detect that an image has an alpha channel?
      -- For now only PNG and GIF (identified by their magic bytes) are
      -- inspected for alpha.
      if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
         -- split alpha channel
         im = im:toTensor('float', 'RGBA', 'DHW')
         -- The sum of (alpha - 1) is zero when every alpha value is 1, in
         -- which case the channel carries no information and is dropped.
         local sum_alpha = (im[4] - 1):sum()
         if sum_alpha ~= 0 then
            alpha = im[4]:reshape(1, im:size(2), im:size(3))
         end
         local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
         new_im[1]:copy(im[1])
         new_im[2]:copy(im[2])
         new_im[3]:copy(im[3])
         im = new_im:mul(255):byte()
      else
         im = im:toTensor('byte', 'RGB', 'DHW')
      end
      return {im, alpha}
   end
   -- Decoding failures are reported to the caller as a nil image rather than
   -- as a raised error.
   local state, ret = pcall(load_image)
   if state then
      return ret[1], ret[2]
   else
      return nil
   end
end

-- Loads `file` and decodes it with decode_float.
-- Raises an error when the file cannot be opened.
function image_loader.load_float(file)
   return image_loader.decode_float(read_file(file))
end

-- Loads `file` and decodes it with decode_byte.
-- Raises an error when the file cannot be opened.
function image_loader.load_byte(file)
   return image_loader.decode_byte(read_file(file))
end

-- Manual smoke test: loads ./a.jpg and ./b.png and displays them.
-- Not run by default (see the commented-out call below).
local function test()
   require 'image'
   local img
   img = image_loader.load_float("./a.jpg")
   if img then
      print(img:min())
      print(img:max())
      image.display(img)
   end
   img = image_loader.load_float("./b.png")
   if img then
      image.display(img)
   end
end
--test()

return image_loader