Переглянути джерело

Improve alpha channel handling #29

- make border
- scale the alpha channel by waifu2x
- composite
nagadomi 9 роки тому
батько
коміт
d2c081bbcf
6 змінених файлів з 225 додано та 71 видалено
  1. 80 0
      lib/alpha_util.lua
  2. 20 35
      lib/image_loader.lua
  3. 18 3
      lib/iproc.lua
  4. 34 8
      lib/reconstruct.lua
  5. 17 11
      waifu2x.lua
  6. 56 14
      web.lua

+ 80 - 0
lib/alpha_util.lua

@@ -0,0 +1,80 @@
+local w2nn = require 'w2nn'
+local reconstruct = require 'reconstruct'
+local image = require 'image'
+local iproc = require 'iproc'
+local gm = require 'graphicsmagick'
+
+alpha_util = {}
+alpha_util.sum2d = nn.SpatialConvolutionMM(1, 1, 3, 3, 1, 1, 1, 1):cuda()
+alpha_util.sum2d.weight:fill(1)
+alpha_util.sum2d.bias:zero()
+
+function alpha_util.make_border(rgb, alpha, offset)
+   if not alpha then
+      return rgb
+   end
+   
+   local mask = alpha:clone()
+   mask[torch.gt(mask, 0.0)] = 1
+   mask[torch.eq(mask, 0.0)] = 0
+   local mask_nega = (mask - 1):abs():byte()
+   local eps = 1.0e-7
+
+   rgb = rgb:clone()
+   rgb[1][mask_nega] = 0
+   rgb[2][mask_nega] = 0
+   rgb[3][mask_nega] = 0
+
+   for i = 1, offset do
+      local mask_weight = alpha_util.sum2d:forward(mask:cuda()):float()
+      local border = rgb:clone()
+      for j = 1, 3 do
+	 border[j]:copy(alpha_util.sum2d:forward(rgb[j]:reshape(1, rgb:size(2), rgb:size(3)):cuda()))
+	 border[j]:cdiv((mask_weight + eps))
+	 rgb[j][mask_nega] = border[j][mask_nega]
+      end
+      mask = mask_weight:clone()
+      mask[torch.gt(mask_weight, 0.0)] = 1
+      mask_nega = (mask - 1):abs():byte()
+   end
+   rgb[torch.gt(rgb, 1.0)] = 1.0
+   rgb[torch.lt(rgb, 0.0)] = 0.0
+
+   return rgb
+end
+function alpha_util.composite(rgb, alpha, model2x)
+   if not alpha then
+      return rgb
+   end
+   if not (alpha:size(2) == rgb:size(2) and  alpha:size(3) == rgb:size(3)) then
+      if model2x then
+	 alpha = reconstruct.scale(model2x, 2.0, alpha)
+      else
+	 alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
+      end
+   end
+   local out = torch.Tensor(4, rgb:size(2), rgb:size(3))
+   out[1]:copy(rgb[1])
+   out[2]:copy(rgb[2])
+   out[3]:copy(rgb[3])
+   out[4]:copy(alpha)
+   return out
+end
+
+local function test()
+   require 'sys'
+   require 'trepl'
+   torch.setdefaulttensortype("torch.FloatTensor")
+
+   local image_loader = require 'image_loader'
+   local rgb, alpha = image_loader.load_float("alpha.png")
+   local t = sys.clock()
+   rgb = alpha_util.make_border(rgb, alpha, 7)
+   print(sys.clock() - t)
+   print(rgb:min(), rgb:max())
+   image.display({image = rgb, min = 0, max = 1})
+   image.save("out.png", rgb)
+end
+--test()
+
+return alpha_util

+ 20 - 35
lib/image_loader.lua

@@ -9,47 +9,32 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
 local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
 local background_color = 0.5
 
-function image_loader.encode_png(rgb, alpha, depth)
+function image_loader.encode_png(rgb, depth)
    depth = depth or 8
    rgb = iproc.byte2float(rgb)
-   if alpha then
-      if not (alpha:size(2) == rgb:size(2) and  alpha:size(3) == rgb:size(3)) then
-	 alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
-      end
-      local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
-      rgba[1]:copy(rgb[1])
-      rgba[2]:copy(rgb[2])
-      rgba[3]:copy(rgb[3])
-      rgba[4]:copy(alpha)
-      
-      if depth < 16 then
-	 rgba:add(clip_eps8)
-	 rgba[torch.lt(rgba, 0.0)] = 0.0
-	 rgba[torch.gt(rgba, 1.0)] = 1.0
-      else
-	 rgba:add(clip_eps16)
-	 rgba[torch.lt(rgba, 0.0)] = 0.0
-	 rgba[torch.gt(rgba, 1.0)] = 1.0
-      end
-      local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
-      return im:depth(depth):format("PNG"):toString(9)
+   if depth < 16 then
+      rgb = rgb:clone():add(clip_eps8)
+      rgb[torch.lt(rgb, 0.0)] = 0.0
+      rgb[torch.gt(rgb, 1.0)] = 1.0
    else
-      if depth < 16 then
-	 rgb = rgb:clone():add(clip_eps8)
-	 rgb[torch.lt(rgb, 0.0)] = 0.0
-	 rgb[torch.gt(rgb, 1.0)] = 1.0
-      else
-	 rgb = rgb:clone():add(clip_eps16)
-	 rgb[torch.lt(rgb, 0.0)] = 0.0
-	 rgb[torch.gt(rgb, 1.0)] = 1.0
-      end
-      local im = gm.Image(rgb, "RGB", "DHW")
-      return im:depth(depth):format("PNG"):toString(9)
+      rgb = rgb:clone():add(clip_eps16)
+      rgb[torch.lt(rgb, 0.0)] = 0.0
+      rgb[torch.gt(rgb, 1.0)] = 1.0
+   end
+   local im
+   if rgb:size(1) == 4 then -- RGBA
+      im = gm.Image(rgb, "RGBA", "DHW")
+   elseif rgb:size(1) == 3 then -- RGB
+      im = gm.Image(rgb, "RGB", "DHW")
+   elseif rgb:size(1) == 1 then -- Y
+      im = gm.Image(rgb, "I", "DHW")
+      -- im:colorspace("GRAY") -- it does not work
    end
+   return im:depth(depth):format("PNG"):toString(9)
 end
-function image_loader.save_png(filename, rgb, alpha, depth)
+function image_loader.save_png(filename, rgb, depth)
    depth = depth or 8
-   local blob = image_loader.encode_png(rgb, alpha, depth)
+   local blob = image_loader.encode_png(rgb, depth)
    local fp = io.open(filename, "wb")
    if not fp then
       error("IO error: " .. filename)

+ 18 - 3
lib/iproc.lua

@@ -49,12 +49,17 @@ function iproc.float2byte(src)
    return dest, conversion
 end
 function iproc.scale(src, width, height, filter)
-   local conversion
+   local conversion, color
    src, conversion = iproc.byte2float(src)
    filter = filter or "Box"
-   local im = gm.Image(src, "RGB", "DHW")
+   if src:size(1) == 3 then
+      color = "RGB"
+   else
+      color = "I"
+   end
+   local im = gm.Image(src, color, "DHW")
    im:size(math.ceil(width), math.ceil(height), filter)
-   local dest = im:toTensor("float", "RGB", "DHW")
+   local dest = im:toTensor("float", color, "DHW")
    if conversion then
       dest = iproc.float2byte(dest)
    end
@@ -84,6 +89,16 @@ function iproc.padding(img, w1, w2, h1, h2)
    flow[2]:add(-w1)
    return image.warp(img, flow, "simple", false, "clamp")
 end
+function iproc.zero_padding(img, w1, w2, h1, h2)
+   local dst_height = img:size(2) + h1 + h2
+   local dst_width = img:size(3) + w1 + w2
+   local flow = torch.Tensor(2, dst_height, dst_width)
+   flow[1] = torch.ger(torch.linspace(0, dst_height -1, dst_height), torch.ones(dst_width))
+   flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
+   flow[1]:add(-h1)
+   flow[2]:add(-w1)
+   return image.warp(img, flow, "simple", false, "pad", 0)
+end
 function iproc.white_noise(src, std, rgb_weights, gamma)
    gamma = gamma or 0.454545
    local conversion

+ 34 - 8
lib/reconstruct.lua

@@ -189,22 +189,48 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
 end
 
 function reconstruct.image(model, x, block_size)
+   local i2rgb = false
+   if x:size(1) == 1 then
+      local new_x = torch.Tensor(3, x:size(2), x:size(3))
+      new_x[1]:copy(x)
+      new_x[2]:copy(x)
+      new_x[3]:copy(x)
+      x = new_x
+      i2rgb = true
+   end
    if reconstruct.is_rgb(model) then
-      return reconstruct.image_rgb(model, x,
-				   reconstruct.offset_size(model), block_size)
+      x = reconstruct.image_rgb(model, x,
+				reconstruct.offset_size(model), block_size)
    else
-      return reconstruct.image_y(model, x,
-				 reconstruct.offset_size(model), block_size)
+      x = reconstruct.image_y(model, x,
+			      reconstruct.offset_size(model), block_size)
+   end
+   if i2rgb then
+      x = image.rgb2y(x)
    end
+   return x
 end
 function reconstruct.scale(model, scale, x, block_size)
+   local i2rgb = false
+   if x:size(1) == 1 then
+      local new_x = torch.Tensor(3, x:size(2), x:size(3))
+      new_x[1]:copy(x)
+      new_x[2]:copy(x)
+      new_x[3]:copy(x)
+      x = new_x
+      i2rgb = true
+   end
    if reconstruct.is_rgb(model) then
-      return reconstruct.scale_rgb(model, scale, x,
-				   reconstruct.offset_size(model), block_size)
+      x = reconstruct.scale_rgb(model, scale, x,
+				reconstruct.offset_size(model), block_size)
    else
-      return reconstruct.scale_y(model, scale, x,
-				 reconstruct.offset_size(model), block_size)
+      x = reconstruct.scale_y(model, scale, x,
+			      reconstruct.offset_size(model), block_size)
+   end
+   if i2rgb then
+      x = image.rgb2y(x)
    end
+   return x
 end
 local function tta(f, model, x, block_size)
    local average = nil

+ 17 - 11
waifu2x.lua

@@ -6,6 +6,7 @@ require 'w2nn'
 local iproc = require 'iproc'
 local reconstruct = require 'reconstruct'
 local image_loader = require 'image_loader'
+local alpha_util = require 'alpha_util'
 
 torch.setdefaulttensortype('torch.FloatTensor')
 
@@ -14,6 +15,7 @@ local function convert_image(opt)
    local new_x = nil
    local t = sys.clock()
    local scale_f, image_f
+
    if opt.tta == 1 then
       scale_f = reconstruct.scale_tta
       image_f = reconstruct.image_tta
@@ -34,13 +36,16 @@ local function convert_image(opt)
 	 error("Load Error: " .. model_path)
       end
       new_x = image_f(model, x, opt.crop_size)
+      new_x = alpha_util.composite(new_x, alpha)
    elseif opt.m == "scale" then
       local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
       local model = torch.load(model_path, "ascii")
       if not model then
 	 error("Load Error: " .. model_path)
       end
+      x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model))
       new_x = scale_f(model, opt.scale, x, opt.crop_size)
+      new_x = alpha_util.composite(new_x, alpha, model)
    elseif opt.m == "noise_scale" then
       local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
       local noise_model = torch.load(noise_model_path, "ascii")
@@ -53,15 +58,14 @@ local function convert_image(opt)
       if not scale_model then
 	 error("Load Error: " .. scale_model_path)
       end
+      x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
       x = image_f(noise_model, x, opt.crop_size)
       new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
+      new_x = alpha_util.composite(new_x, alpha, scale_model)
    else
       error("undefined method:" .. opt.method)
    end
-   if opt.white_noise == 1 then
-      new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
-   end
-   image_loader.save_png(opt.o, new_x, alpha, opt.depth)
+   image_loader.save_png(opt.o, new_x, opt.depth)
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 local function convert_frames(opt)
@@ -128,23 +132,27 @@ local function convert_frames(opt)
 	 local new_x = nil
 	 if opt.m == "noise" and opt.noise_level == 1 then
 	    new_x = image_f(noise1_model, x, opt.crop_size)
+	    new_x = alpha_util.composite(new_x, alpha)
 	 elseif opt.m == "noise" and opt.noise_level == 2 then
 	    new_x = image_func(noise2_model, x, opt.crop_size)
+	    new_x = alpha_util.composite(new_x, alpha)
 	 elseif opt.m == "scale" then
+	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
 	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
+	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 elseif opt.m == "noise_scale" and opt.noise_level == 1 then
+	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
 	    x = image_f(noise1_model, x, opt.crop_size)
 	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
+	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 elseif opt.m == "noise_scale" and opt.noise_level == 2 then
+	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
 	    x = image_f(noise2_model, x, opt.crop_size)
 	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
+	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 else
 	    error("undefined method:" .. opt.method)
 	 end
-	 if opt.white_noise == 1 then
-	    new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
-	 end
-
 	 local output = nil
 	 if opt.o == "(auto)" then
 	    local name = path.basename(lines[i])
@@ -154,7 +162,7 @@ local function convert_frames(opt)
 	 else
 	    output = string.format(opt.o, i)
 	 end
-	 image_loader.save_png(output, new_x, alpha, opt.depth)
+	 image_loader.save_png(output, new_x, opt.depth)
 	 xlua.progress(i, #lines)
 	 if i % 10 == 0 then
 	    collectgarbage()
@@ -182,8 +190,6 @@ local function waifu2x()
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")
    cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
-   cmd:option("-white_noise", 0, 'adding white noise to output image (0|1)')
-   cmd:option("-white_noise_std", 0.0055, 'standard division of white noise')
    
    local opt = cmd:parse(arg)
    if opt.thread > 0 then

+ 56 - 14
web.lua

@@ -11,6 +11,7 @@ local md5 = require 'md5'
 local iproc = require 'iproc'
 local reconstruct = require 'reconstruct'
 local image_loader = require 'image_loader'
+local alpha_util = require 'alpha_util'
 
 -- Notes:  turbo and xlua has different implementation of string:split().
 --         Therefore, string:split() has conflict issue.
@@ -104,14 +105,36 @@ local function cleanup_model(model)
       w2nn.cleanup_model(model) -- release GPU memory
    end
 end
-local function convert(x, options)
+local function convert(x, alpha, options)
    local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
+   local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
+   local alpha_orig = alpha
+
+   if path.exists(alpha_cache_file) then
+      alpha = image_loader.load_float(alpha_cache_file)
+      if alpha:dim() == 2 then
+	 alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
+      end
+      if alpha:size(1) == 3 then
+	 alpha = image.rgb2y(alpha)
+      end
+   end
    if path.exists(cache_file) then
-      return image.load(cache_file)
+      x = image_loader.load_float(cache_file)
+      return x, alpha
    else
       if options.style == "art" then
+	 if options.border then
+	    x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model))
+	 end
 	 if options.method == "scale" then
 	    x = reconstruct.scale(art_scale2_model, 2.0, x)
+	    if alpha then
+	       if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
+		  alpha = reconstruct.scale(art_scale2_model, 2.0, alpha)
+		  image_loader.save_png(alpha_cache_file, alpha)
+	       end
+	    end
 	    cleanup_model(art_scale2_model)
 	 elseif options.method == "noise1" then
 	    x = reconstruct.image(art_noise1_model, x)
@@ -121,8 +144,17 @@ local function convert(x, options)
 	    cleanup_model(art_noise2_model)
 	 end
       else --[[photo
+	 if options.border then
+	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
+	 end
 	 if options.method == "scale" then
 	    x = reconstruct.scale(photo_scale2_model, 2.0, x)
+	    if alpha then
+	       if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
+		  alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha)
+		  image_loader.save_png(alpha_cache_file, alpha)
+	       end
+	    end
 	    cleanup_model(photo_scale2_model)
 	 elseif options.method == "noise1" then
 	    x = reconstruct.image(photo_noise1_model, x)
@@ -133,8 +165,9 @@ local function convert(x, options)
 	 end
       --]]
       end
-      image.save(cache_file, x)
-      return x
+      image_loader.save_png(cache_file, x)
+
+      return x, alpha
    end
 end
 local function client_disconnected(handler)
@@ -154,7 +187,6 @@ function APIHandler:post()
    local x, alpha, blob = get_image(self)
    local scale = tonumber(self:get_argument("scale", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))
-   local white_noise = tonumber(self:get_argument("white_noise", "0"))
    local style = self:get_argument("style", "art")
    local download = (self:get_argument("download", "")):len()
 
@@ -164,29 +196,39 @@ function APIHandler:post()
    if x and valid_size(x, scale) then
       if (noise ~= 0 or scale ~= 0) then
 	 local hash = md5.sumhexa(blob)
+	 local alpha_prefix = style .. "_" .. hash .. "_alpha"
+	 local border = false
+	 if scale ~= 0 and alpha then
+	    border = true
+	 end
 	 if noise == 1 then
-	    x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash})
+	    x = convert(x, alpha, {method = "noise1", style = style,
+				   prefix = style .. "_noise1_" .. hash,
+				   alpha_prefix = alpha_prefix, border = border})
+	    border = false
 	 elseif noise == 2 then
-	    x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash})
+	    x = convert(x, alpha, {method = "noise2", style = style,
+				   prefix = style .. "_noise2_" .. hash, 
+				   alpha_prefix = alpha_prefix, border = border})
+	    border = false
 	 end
 	 if scale == 1 or scale == 2 then
+	    local prefix
 	    if noise == 1 then
-	       x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash})
+	       prefix = style .. "_noise1_scale_" .. hash
 	    elseif noise == 2 then
-	       x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash})
+	       prefix = style .. "_noise2_scale_" .. hash
 	    else
-	       x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash})
+	       prefix = style .. "_scale_" .. hash
 	    end
+	    x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix, alpha_prefix = alpha_prefix, border = border})
 	    if scale == 1 then
 	       x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
 	    end
 	 end
-	 if white_noise == 1 then
-	    x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0})
-	 end
       end
       local name = uuid() .. ".png"
-      local blob = image_loader.encode_png(x, alpha)
+      local blob = image_loader.encode_png(alpha_util.composite(x, alpha))
       self:set_header("Content-Disposition", string.format('filename="%s"', name))
       self:set_header("Content-Length", string.format("%d", #blob))
       if download > 0 then