Browse Source

Fix embed gamma handling

nagadomi 9 years ago
parent
commit
fbad30c031
4 changed files with 61 additions and 55 deletions
  1. 2 2
      convert_data.lua
  2. 27 26
      lib/image_loader.lua
  3. 7 4
      waifu2x.lua
  4. 25 23
      web.lua

+ 2 - 2
convert_data.lua

@@ -33,8 +33,8 @@ local function load_images(list)
    local x = {}
    local x = {}
    for i = 1, #lines do
    for i = 1, #lines do
       local line = lines[i]
       local line = lines[i]
-      local im, alpha = image_loader.load_byte(line)
-      if alpha then
+      local im, meta = image_loader.load_byte(line)
+      if meta and meta.alpha then
 	 io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
 	 io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
       else
       else
 	 if settings.max_training_image_size > 0 then
 	 if settings.max_training_image_size > 0 then

+ 27 - 26
lib/image_loader.lua

@@ -9,14 +9,15 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
 local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
 local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
 local background_color = 0.5
 local background_color = 0.5
 
 
-function image_loader.encode_png(rgb, depth, inplace)
-   if inplace == nil then
-      inplace = false
+function image_loader.encode_png(rgb, options)
+   options = options or {}
+   options.depth = options.depth or 8
+   if options.inplace == nil then
+      options.inplace = false
    end
    end
-   depth = depth or 8
    rgb = iproc.byte2float(rgb)
    rgb = iproc.byte2float(rgb)
-   if depth < 16 then
-      if inplace then
+   if options.depth < 16 then
+      if options.inplace then
 	 rgb:add(clip_eps8)
 	 rgb:add(clip_eps8)
       else
       else
 	 rgb = rgb:clone():add(clip_eps8)
 	 rgb = rgb:clone():add(clip_eps8)
@@ -25,7 +26,7 @@ function image_loader.encode_png(rgb, depth, inplace)
       rgb[torch.gt(rgb, 1.0)] = 1.0
       rgb[torch.gt(rgb, 1.0)] = 1.0
       rgb = rgb:mul(255):floor():div(255)
       rgb = rgb:mul(255):floor():div(255)
    else
    else
-      if inplace then
+      if options.inplace then
 	 rgb:add(clip_eps16)
 	 rgb:add(clip_eps16)
       else
       else
 	 rgb = rgb:clone():add(clip_eps16)
 	 rgb = rgb:clone():add(clip_eps16)
@@ -43,10 +44,13 @@ function image_loader.encode_png(rgb, depth, inplace)
       im = gm.Image(rgb, "I", "DHW")
       im = gm.Image(rgb, "I", "DHW")
       -- im:colorspace("GRAY") -- it does not work
       -- im:colorspace("GRAY") -- it does not work
    end
    end
-   return im:depth(depth):format("PNG"):toString(9)
+   if options.gamma then
+      im:gamma(options.gamma)
+   end
+   return im:depth(options.depth):format("PNG"):toString(9)
 end
 end
-function image_loader.save_png(filename, rgb, depth, inplace)
-   local blob = image_loader.encode_png(rgb, depth, inplace)
+function image_loader.save_png(filename, rgb, options)
+   local blob = image_loader.encode_png(rgb, options)
    local fp = io.open(filename, "wb")
    local fp = io.open(filename, "wb")
    if not fp then
    if not fp then
       error("IO error: " .. filename)
       error("IO error: " .. filename)
@@ -57,8 +61,8 @@ function image_loader.save_png(filename, rgb, depth, inplace)
 end
 end
 function image_loader.decode_float(blob)
 function image_loader.decode_float(blob)
    local load_image = function()
    local load_image = function()
+      local meta = {}
       local im = gm.Image()
       local im = gm.Image()
-      local alpha = nil
       local gamma_lcd = 0.454545
       local gamma_lcd = 0.454545
       
       
       im:fromBlob(blob, #blob)
       im:fromBlob(blob, #blob)
@@ -66,12 +70,8 @@ function image_loader.decode_float(blob)
       if im:colorspace() == "CMYK" then
       if im:colorspace() == "CMYK" then
 	 im:colorspace("RGB")
 	 im:colorspace("RGB")
       end
       end
-      local gamma = math.floor(im:gamma() * 1000000) / 1000000
-      if gamma ~= 0 and gamma ~= gamma_lcd then
-	 local cg = gamma / gamma_lcd
-	 im:gammaCorrection(cg, "Red")
-	 im:gammaCorrection(cg, "Blue")
-	 im:gammaCorrection(cg, "Green")
+      if gamma ~= 0 and math.floor(im:gamma() * 1000000) / 1000000 ~= gamma_lcd then
+	 meta.gamma = im:gamma()
       end
       end
       -- FIXME: How to detect that a image has an alpha channel?
       -- FIXME: How to detect that a image has an alpha channel?
       if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
       if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
@@ -79,9 +79,9 @@ function image_loader.decode_float(blob)
 	 im = im:toTensor('float', 'RGBA', 'DHW')
 	 im = im:toTensor('float', 'RGBA', 'DHW')
 	 local sum_alpha = (im[4] - 1.0):sum()
 	 local sum_alpha = (im[4] - 1.0):sum()
 	 if sum_alpha < 0 then
 	 if sum_alpha < 0 then
-	    alpha = im[4]:reshape(1, im:size(2), im:size(3))
+	    meta.alpha = im[4]:reshape(1, im:size(2), im:size(3))
 	    -- drop full transparent background
 	    -- drop full transparent background
-	    local mask = torch.le(alpha, 0.0)
+	    local mask = torch.le(meta.alpha, 0.0)
 	    im[1][mask] = background_color
 	    im[1][mask] = background_color
 	    im[2][mask] = background_color
 	    im[2][mask] = background_color
 	    im[3][mask] = background_color
 	    im[3][mask] = background_color
@@ -94,25 +94,26 @@ function image_loader.decode_float(blob)
       else
       else
 	 im = im:toTensor('float', 'RGB', 'DHW')
 	 im = im:toTensor('float', 'RGB', 'DHW')
       end
       end
-      return {im, alpha, blob}
+      meta.blob = blob
+      return {im, meta}
    end
    end
    local state, ret = pcall(load_image)
    local state, ret = pcall(load_image)
    if state then
    if state then
-      return ret[1], ret[2], ret[3]
+      return ret[1], ret[2]
    else
    else
-      return nil, nil, nil
+      return nil, nil
    end
    end
 end
 end
 function image_loader.decode_byte(blob)
 function image_loader.decode_byte(blob)
-   local im, alpha
-   im, alpha, blob = image_loader.decode_float(blob)
+   local im, meta
+   im, meta = image_loader.decode_float(blob)
    
    
    if im then
    if im then
       im = iproc.float2byte(im)
       im = iproc.float2byte(im)
       -- hmm, alpha does not convert here
       -- hmm, alpha does not convert here
-      return im, alpha, blob
+      return im, meta
    else
    else
-      return nil, nil, nil
+      return nil, nil
    end
    end
 end
 end
 function image_loader.load_float(file)
 function image_loader.load_float(file)

+ 7 - 4
waifu2x.lua

@@ -11,7 +11,8 @@ local alpha_util = require 'alpha_util'
 torch.setdefaulttensortype('torch.FloatTensor')
 torch.setdefaulttensortype('torch.FloatTensor')
 
 
 local function convert_image(opt)
 local function convert_image(opt)
-   local x, alpha = image_loader.load_float(opt.i)
+   local x, meta = image_loader.load_float(opt.i)
+   local alpha = meta.alpha
    local new_x = nil
    local new_x = nil
    local t = sys.clock()
    local t = sys.clock()
    local scale_f, image_f
    local scale_f, image_f
@@ -65,7 +66,7 @@ local function convert_image(opt)
    else
    else
       error("undefined method:" .. opt.method)
       error("undefined method:" .. opt.method)
    end
    end
-   image_loader.save_png(opt.o, new_x, opt.depth, true)
+   image_loader.save_png(opt.o, new_x, {depth = opt.depth, inplace = true, gamma = meta.gamma})
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 end
 local function convert_frames(opt)
 local function convert_frames(opt)
@@ -115,7 +116,8 @@ local function convert_frames(opt)
    fp:close()
    fp:close()
    for i = 1, #lines do
    for i = 1, #lines do
       if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
       if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
-	 local x, alpha = image_loader.load_float(lines[i])
+	 local x, meta = image_loader.load_float(lines[i])
+	 local alpha = meta.alpha
 	 local new_x = nil
 	 local new_x = nil
 	 if opt.m == "noise" then
 	 if opt.m == "noise" then
 	    new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
 	    new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
@@ -141,7 +143,8 @@ local function convert_frames(opt)
 	 else
 	 else
 	    output = string.format(opt.o, i)
 	    output = string.format(opt.o, i)
 	 end
 	 end
-	 image_loader.save_png(output, new_x, opt.depth, true)
+	 image_loader.save_png(output, new_x, 
+			       {depth = opt.depth, inplace = true, gamma = meta.gamma})
 	 xlua.progress(i, #lines)
 	 xlua.progress(i, #lines)
 	 if i % 10 == 0 then
 	 if i % 10 == 0 then
 	    collectgarbage()
 	    collectgarbage()

+ 25 - 23
web.lua

@@ -93,7 +93,7 @@ local function cache_url(url)
 	 end
 	 end
       end
       end
    end
    end
-   return nil, nil, nil
+   return nil, nil
 end
 end
 local function get_image(req)
 local function get_image(req)
    local file_info = req:get_arguments("file")
    local file_info = req:get_arguments("file")
@@ -108,22 +108,23 @@ local function get_image(req)
       end
       end
    end
    end
    if file and file:len() > 0 then
    if file and file:len() > 0 then
-      local x, alpha, blob = image_loader.decode_float(file)
-      return x, alpha, blob, filename
+      local x, meta = image_loader.decode_float(file)
+      return x, meta, filename
    elseif url and url:len() > 0 then
    elseif url and url:len() > 0 then
-      local x, alpha, blob = cache_url(url)
-      return x, alpha, blob, filename
+      local x, meta = cache_url(url)
+      return x, meta, filename
    end
    end
-   return nil, nil, nil, nil
+   return nil, nil, nil
 end
 end
 local function cleanup_model(model)
 local function cleanup_model(model)
    if CLEANUP_MODEL then
    if CLEANUP_MODEL then
       model:clearState() -- release GPU memory
       model:clearState() -- release GPU memory
    end
    end
 end
 end
-local function convert(x, alpha, options)
+local function convert(x, meta, options)
    local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
    local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
    local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
    local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
+   local alpha = meta.alpha
    local alpha_orig = alpha
    local alpha_orig = alpha
 
 
    if path.exists(alpha_cache_file) then
    if path.exists(alpha_cache_file) then
@@ -137,7 +138,7 @@ local function convert(x, alpha, options)
    end
    end
    if path.exists(cache_file) then
    if path.exists(cache_file) then
       x = image_loader.load_float(cache_file)
       x = image_loader.load_float(cache_file)
-      return x, alpha
+      return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
    else
    else
       if options.style == "art" then
       if options.style == "art" then
 	 if options.border then
 	 if options.border then
@@ -192,7 +193,7 @@ local function convert(x, alpha, options)
       end
       end
       image_loader.save_png(cache_file, x)
       image_loader.save_png(cache_file, x)
 
 
-      return x, alpha
+      return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
    end
    end
 end
 end
 local function client_disconnected(handler)
 local function client_disconnected(handler)
@@ -218,7 +219,7 @@ function APIHandler:post()
       self:write("client disconnected")
       self:write("client disconnected")
       return
       return
    end
    end
-   local x, alpha, blob, filename = get_image(self)
+   local x, meta, filename = get_image(self)
    local scale = tonumber(self:get_argument("scale", "0"))
    local scale = tonumber(self:get_argument("scale", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))
    local style = self:get_argument("style", "art")
    local style = self:get_argument("style", "art")
@@ -230,29 +231,29 @@ function APIHandler:post()
    if x and valid_size(x, scale) then
    if x and valid_size(x, scale) then
       local prefix = nil
       local prefix = nil
       if (noise ~= 0 or scale ~= 0) then
       if (noise ~= 0 or scale ~= 0) then
-	 local hash = md5.sumhexa(blob)
+	 local hash = md5.sumhexa(meta.blob)
 	 local alpha_prefix = style .. "_" .. hash .. "_alpha"
 	 local alpha_prefix = style .. "_" .. hash .. "_alpha"
 	 local border = false
 	 local border = false
-	 if scale ~= 0 and alpha then
+	 if scale ~= 0 and meta.alpha then
 	    border = true
 	    border = true
 	 end
 	 end
 	 if noise == 1 then
 	 if noise == 1 then
 	    prefix = style .. "_noise1_"
 	    prefix = style .. "_noise1_"
-	    x = convert(x, alpha, {method = "noise1", style = style,
-				   prefix = prefix .. hash,
-				   alpha_prefix = alpha_prefix, border = border})
+	    x = convert(x, meta, {method = "noise1", style = style,
+				  prefix = prefix .. hash,
+				  alpha_prefix = alpha_prefix, border = border})
 	    border = false
 	    border = false
 	 elseif noise == 2 then
 	 elseif noise == 2 then
 	    prefix = style .. "_noise2_"
 	    prefix = style .. "_noise2_"
-	    x = convert(x, alpha, {method = "noise2", style = style,
-				   prefix = prefix .. hash, 
-				   alpha_prefix = alpha_prefix, border = border})
+	    x = convert(x, meta, {method = "noise2", style = style,
+				  prefix = prefix .. hash, 
+				  alpha_prefix = alpha_prefix, border = border})
 	    border = false
 	    border = false
 	 elseif noise == 3 then
 	 elseif noise == 3 then
 	    prefix = style .. "_noise3_"
 	    prefix = style .. "_noise3_"
-	    x = convert(x, alpha, {method = "noise3", style = style,
-				   prefix = prefix .. hash, 
-				   alpha_prefix = alpha_prefix, border = border})
+	    x = convert(x, meta, {method = "noise3", style = style,
+				  prefix = prefix .. hash, 
+				  alpha_prefix = alpha_prefix, border = border})
 	    border = false
 	    border = false
 	 end
 	 end
 	 if scale == 1 or scale == 2 then
 	 if scale == 1 or scale == 2 then
@@ -265,7 +266,7 @@ function APIHandler:post()
 	    else
 	    else
 	       prefix = style .. "_scale_"
 	       prefix = style .. "_scale_"
 	    end
 	    end
-	    x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
+	    x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
 	    if scale == 1 then
 	    if scale == 1 then
 	       x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
 	       x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
 	    end
 	    end
@@ -281,7 +282,8 @@ function APIHandler:post()
       else
       else
 	 name = uuid() .. ".png"
 	 name = uuid() .. ".png"
       end
       end
-      local blob = image_loader.encode_png(alpha_util.composite(x, alpha), 8, true)
+      local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
+					   { depth = 8, inplace = true, gamma = meta.gamma})
 
 
       self:set_header("Content-Length", string.format("%d", #blob))
       self:set_header("Content-Length", string.format("%d", #blob))
       if download > 0 then
       if download > 0 then