|
@@ -1,5 +1,6 @@
|
|
|
local gm = require 'graphicsmagick'
|
|
|
local ffi = require 'ffi'
|
|
|
+local iproc = require 'iproc'
|
|
|
require 'pl'
|
|
|
|
|
|
local image_loader = {}
|
|
@@ -8,18 +9,9 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
|
|
local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
|
|
|
local background_color = 0.5
|
|
|
|
|
|
-function image_loader.decode_float(blob)
|
|
|
- local im, alpha = image_loader.decode_byte(blob)
|
|
|
- if im then
|
|
|
- im = im:float():div(255)
|
|
|
- end
|
|
|
- return im, alpha, blob
|
|
|
-end
|
|
|
function image_loader.encode_png(rgb, alpha, depth)
|
|
|
depth = depth or 8
|
|
|
- if rgb:type() == "torch.ByteTensor" then
|
|
|
- rgb = rgb:float():div(255)
|
|
|
- end
|
|
|
+ rgb = iproc.byte2float(rgb)
|
|
|
if alpha then
|
|
|
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
|
|
|
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
|
|
@@ -66,7 +58,7 @@ function image_loader.save_png(filename, rgb, alpha, depth)
|
|
|
fp:close()
|
|
|
return true
|
|
|
end
|
|
|
-function image_loader.decode_byte(blob)
|
|
|
+function image_loader.decode_float(blob)
|
|
|
local load_image = function()
|
|
|
local im = gm.Image()
|
|
|
local alpha = nil
|
|
@@ -98,9 +90,9 @@ function image_loader.decode_byte(blob)
|
|
|
new_im[1]:copy(im[1])
|
|
|
new_im[2]:copy(im[2])
|
|
|
new_im[3]:copy(im[3])
|
|
|
- im = new_im:mul(255):byte()
|
|
|
+ im = new_im
|
|
|
else
|
|
|
- im = im:toTensor('byte', 'RGB', 'DHW')
|
|
|
+ im = im:toTensor('float', 'RGB', 'DHW')
|
|
|
end
|
|
|
return {im, alpha, blob}
|
|
|
end
|
|
@@ -111,6 +103,18 @@ function image_loader.decode_byte(blob)
|
|
|
return nil, nil, nil
|
|
|
end
|
|
|
end
|
|
|
+function image_loader.decode_byte(blob)
|
|
|
+ local im, alpha
|
|
|
+ im, alpha, blob = image_loader.decode_float(blob)
|
|
|
+
|
|
|
+ if im then
|
|
|
+ im = iproc.float2byte(im)
|
|
|
+ -- hmm, alpha does not convert here
|
|
|
+ return im, alpha, blob
|
|
|
+ else
|
|
|
+ return nil, nil, nil
|
|
|
+ end
|
|
|
+end
|
|
|
function image_loader.load_float(file)
|
|
|
local fp = io.open(file, "rb")
|
|
|
if not fp then
|
|
@@ -130,18 +134,16 @@ function image_loader.load_byte(file)
|
|
|
return image_loader.decode_byte(buff)
|
|
|
end
|
|
|
local function test()
|
|
|
- require 'image'
|
|
|
- local img
|
|
|
- img = image_loader.load_float("./a.jpg")
|
|
|
- if img then
|
|
|
- print(img:min())
|
|
|
- print(img:max())
|
|
|
- image.display(img)
|
|
|
- end
|
|
|
- img = image_loader.load_float("./b.png")
|
|
|
- if img then
|
|
|
- image.display(img)
|
|
|
- end
|
|
|
+ torch.setdefaulttensortype("torch.FloatTensor")
|
|
|
+ local a = image_loader.load_float("../images/lena.png")
|
|
|
+ local blob, len = image_loader.encode_png(a)
|
|
|
+ local b = image_loader.decode_float(ffi.string(blob, len))
|
|
|
+ assert((b - a):abs():sum() == 0)
|
|
|
+
|
|
|
+ a = image_loader.load_byte("../images/lena.png")
|
|
|
+ blob, len = image_loader.encode_png(a)
|
|
|
+ b = image_loader.decode_byte(ffi.string(blob, len))
|
|
|
+ assert((b:float() - a:float()):abs():sum() == 0)
|
|
|
end
|
|
|
--test()
|
|
|
return image_loader
|