|
@@ -11,7 +11,8 @@ function image_loader.decode_float(blob)
|
|
end
|
|
end
|
|
return im, alpha, blob
|
|
return im, alpha, blob
|
|
end
|
|
end
|
|
-function image_loader.encode_png(rgb, alpha)
|
|
|
|
|
|
+function image_loader.encode_png(rgb, alpha, depth)
|
|
|
|
+ depth = depth or 8
|
|
if rgb:type() == "torch.ByteTensor" then
|
|
if rgb:type() == "torch.ByteTensor" then
|
|
rgb = rgb:float():div(255)
|
|
rgb = rgb:float():div(255)
|
|
end
|
|
end
|
|
@@ -26,15 +27,16 @@ function image_loader.encode_png(rgb, alpha)
|
|
rgba[4]:copy(alpha)
|
|
rgba[4]:copy(alpha)
|
|
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
|
|
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
|
|
im:format("png")
|
|
im:format("png")
|
|
- return im:toBlob(9)
|
|
|
|
|
|
+ return im:depth(depth):toBlob(9)
|
|
else
|
|
else
|
|
local im = gm.Image(rgb, "RGB", "DHW")
|
|
local im = gm.Image(rgb, "RGB", "DHW")
|
|
im:format("png")
|
|
im:format("png")
|
|
- return im:toBlob(9)
|
|
|
|
|
|
+ return im:depth(depth):toBlob(9)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
-function image_loader.save_png(filename, rgb, alpha)
|
|
|
|
- local blob, len = image_loader.encode_png(rgb, alpha)
|
|
|
|
|
|
+function image_loader.save_png(filename, rgb, alpha, depth)
|
|
|
|
+ depth = depth or 8
|
|
|
|
+ local blob, len = image_loader.encode_png(rgb, alpha, depth)
|
|
local fp = io.open(filename, "wb")
|
|
local fp = io.open(filename, "wb")
|
|
if not fp then
|
|
if not fp then
|
|
error("IO error: " .. filename)
|
|
error("IO error: " .. filename)
|