Browse Source

Add support for grayscale output

nagadomi 9 năm trước cách đây
mục cha
commit
3a27e122ac
3 tập tin đã thay đổi với 19 bổ sung6 xóa
  1. 10 1
      lib/image_loader.lua
  2. 2 2
      waifu2x.lua
  3. 7 3
      web.lua

+ 10 - 1
lib/image_loader.lua

@@ -38,11 +38,17 @@ function image_loader.encode_png(rgb, options)
    local im
    if rgb:size(1) == 4 then -- RGBA
       im = gm.Image(rgb, "RGBA", "DHW")
+      if options.grayscale then
+	 im:type("GrayscaleMatte")
+      end
    elseif rgb:size(1) == 3 then -- RGB
       im = gm.Image(rgb, "RGB", "DHW")
+      if options.grayscale then
+	 im:type("Grayscale")
+      end
    elseif rgb:size(1) == 1 then -- Y
       im = gm.Image(rgb, "I", "DHW")
-      -- im:colorspace("GRAY") -- it does not work
+      im:type("Grayscale")
    end
    if options.gamma then
       im:gamma(options.gamma)
@@ -74,6 +80,9 @@ function image_loader.decode_float(blob)
 	 meta.gamma = im:gamma()
       end
       local image_type = im:type()
+      if image_type == "Grayscale" or image_type == "GrayscaleMatte" then
+	 meta.grayscale = true
+      end
       if image_type == "TrueColorMatte" or image_type == "GrayscaleMatte" then
 	 -- split alpha channel
 	 im = im:toTensor('float', 'RGBA', 'DHW')

+ 2 - 2
waifu2x.lua

@@ -66,7 +66,7 @@ local function convert_image(opt)
    else
       error("undefined method:" .. opt.method)
    end
-   image_loader.save_png(opt.o, new_x, {depth = opt.depth, inplace = true, gamma = meta.gamma})
+   image_loader.save_png(opt.o, new_x, tablex.update({depth = opt.depth, inplace = true}, meta))
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 local function convert_frames(opt)
@@ -144,7 +144,7 @@ local function convert_frames(opt)
 	    output = string.format(opt.o, i)
 	 end
 	 image_loader.save_png(output, new_x, 
-			       {depth = opt.depth, inplace = true, gamma = meta.gamma})
+			       tablex.update({depth = opt.depth, inplace = true}, meta))
 	 xlua.progress(i, #lines)
 	 if i % 10 == 0 then
 	    collectgarbage()

+ 7 - 3
web.lua

@@ -138,7 +138,9 @@ local function convert(x, meta, options)
    end
    if path.exists(cache_file) then
       x = image_loader.load_float(cache_file)
-      return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
+      meta = tablex.copy(meta)
+      meta.alpha = alpha
+      return x, meta
    else
       if options.style == "art" then
 	 if options.border then
@@ -192,8 +194,10 @@ local function convert(x, meta, options)
 	 end
       end
       image_loader.save_png(cache_file, x)
+      meta = tablex.copy(meta)
+      meta.alpha = alpha
 
-      return x, {alpha = alpha, gamma = meta.gamma, blob = meta.blob}
+      return x, meta
    end
 end
 local function client_disconnected(handler)
@@ -283,7 +287,7 @@ function APIHandler:post()
 	 name = uuid() .. ".png"
       end
       local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
-					   { depth = 8, inplace = true, gamma = meta.gamma})
+					   tablex.update({depth = 8, inplace = true}, meta))
 
       self:set_header("Content-Length", string.format("%d", #blob))
       if download > 0 then