浏览代码

directly load float data

nagadomi 9 年之前
父节点
当前提交
3c1c11d88e
共有 1 个文件被更改,包括 27 次插入25 次删除
  1. 27 25
      lib/image_loader.lua

+ 27 - 25
lib/image_loader.lua

@@ -1,5 +1,6 @@
 local gm = require 'graphicsmagick'
 local ffi = require 'ffi'
+local iproc = require 'iproc'
 require 'pl'
 
 local image_loader = {}
@@ -8,18 +9,9 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
 local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
 local background_color = 0.5
 
-function image_loader.decode_float(blob)
-   local im, alpha = image_loader.decode_byte(blob)
-   if im then
-      im = im:float():div(255)
-   end
-   return im, alpha, blob
-end
 function image_loader.encode_png(rgb, alpha, depth)
    depth = depth or 8
-   if rgb:type() == "torch.ByteTensor" then
-      rgb = rgb:float():div(255)
-   end
+   rgb = iproc.byte2float(rgb)
    if alpha then
       if not (alpha:size(2) == rgb:size(2) and  alpha:size(3) == rgb:size(3)) then
 	 alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
@@ -66,7 +58,7 @@ function image_loader.save_png(filename, rgb, alpha, depth)
    fp:close()
    return true
 end
-function image_loader.decode_byte(blob)
+function image_loader.decode_float(blob)
    local load_image = function()
       local im = gm.Image()
       local alpha = nil
@@ -98,9 +90,9 @@ function image_loader.decode_byte(blob)
 	 new_im[1]:copy(im[1])
 	 new_im[2]:copy(im[2])
 	 new_im[3]:copy(im[3])
-	 im = new_im:mul(255):byte()
+	 im = new_im
       else
-	 im = im:toTensor('byte', 'RGB', 'DHW')
+	 im = im:toTensor('float', 'RGB', 'DHW')
       end
       return {im, alpha, blob}
    end
@@ -111,6 +103,18 @@ function image_loader.decode_byte(blob)
       return nil, nil, nil
    end
 end
+function image_loader.decode_byte(blob)
+   local im, alpha
+   im, alpha, blob = image_loader.decode_float(blob)
+   
+   if im then
+      im = iproc.float2byte(im)
+      -- hmm, alpha does not convert here
+      return im, alpha, blob
+   else
+      return nil, nil, nil
+   end
+end
 function image_loader.load_float(file)
    local fp = io.open(file, "rb")
    if not fp then
@@ -130,18 +134,16 @@ function image_loader.load_byte(file)
    return image_loader.decode_byte(buff)
 end
 local function test()
-   require 'image'
-   local img
-   img = image_loader.load_float("./a.jpg")
-   if img then
-      print(img:min())
-      print(img:max())
-      image.display(img)
-   end
-   img = image_loader.load_float("./b.png")
-   if img then
-      image.display(img)
-   end
+   torch.setdefaulttensortype("torch.FloatTensor")
+   local a = image_loader.load_float("../images/lena.png")
+   local blob, len = image_loader.encode_png(a)
+   local b = image_loader.decode_float(ffi.string(blob, len))
+   assert((b - a):abs():sum() == 0)
+
+   a = image_loader.load_byte("../images/lena.png")
+   blob, len = image_loader.encode_png(a)
+   b = image_loader.decode_byte(ffi.string(blob, len))
+   assert((b:float() - a:float()):abs():sum() == 0)
 end
 --test()
 return image_loader