|
@@ -73,19 +73,16 @@ function image_loader.decode_float(blob)
|
|
|
if gamma ~= 0 and math.floor(im:gamma() * 1000000) / 1000000 ~= gamma_lcd then
|
|
|
meta.gamma = im:gamma()
|
|
|
end
|
|
|
- -- FIXME: How to detect that a image has an alpha channel?
|
|
|
- if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
|
|
|
+ local image_type = im:type()
|
|
|
+ if image_type == "TrueColorMatte" or image_type == "GrayscaleMatte" then
|
|
|
-- split alpha channel
|
|
|
im = im:toTensor('float', 'RGBA', 'DHW')
|
|
|
- local sum_alpha = (im[4] - 1.0):sum()
|
|
|
- if sum_alpha < 0 then
|
|
|
- meta.alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
|
|
- -- drop full transparent background
|
|
|
- local mask = torch.le(meta.alpha, 0.0)
|
|
|
- im[1][mask] = background_color
|
|
|
- im[2][mask] = background_color
|
|
|
- im[3][mask] = background_color
|
|
|
- end
|
|
|
+ meta.alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
|
|
+ -- drop full transparent background
|
|
|
+ local mask = torch.le(meta.alpha, 0.0)
|
|
|
+ im[1][mask] = background_color
|
|
|
+ im[2][mask] = background_color
|
|
|
+ im[3][mask] = background_color
|
|
|
local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
|
|
|
new_im[1]:copy(im[1])
|
|
|
new_im[2]:copy(im[2])
|