Prechádzať zdrojové kódy

Merge branch 'master' of github.com:nagadomi/waifu2x into rgb

nagadomi pred 10 rokmi
rodič
commit
628bd971c9
4 zmenené súbory: 91 pridaní a 60 odobraní
  1. 4 1
      README.md
  2. 43 15
      lib/image_loader.lua
  3. 37 38
      waifu2x.lua
  4. 7 6
      web.lua

+ 4 - 1
README.md

@@ -1,6 +1,6 @@
 # waifu2x
 
-Image Super-Resolution for anime/fan-art using Deep Convolutional Neural Networks.
+Image Super-Resolution for anime-style-art using Deep Convolutional Neural Networks.
 
 Demo-Application can be found at http://waifu2x.udp.jp/ .
 
@@ -20,6 +20,9 @@ waifu2x is inspired by SRCNN [1]. 2D character picture (HatsuneMiku) is licensed
 ## Public AMI
 (maintenance)
 
+## Third Party Software
+[Third-Party](https://github.com/nagadomi/waifu2x/wiki/Third-Party)
+
 ## Dependencies
 
 ### Hardware

+ 43 - 15
lib/image_loader.lua

@@ -1,44 +1,72 @@
 local gm = require 'graphicsmagick'
+local ffi = require 'ffi'
 require 'pl'
 
 local image_loader = {}
 
 function image_loader.decode_float(blob)
-   local im = image_loader.decode_byte(blob)
+   local im, alpha = image_loader.decode_byte(blob)
    if im then
       im = im:float():div(255)
    end
-   return im
+   return im, alpha
 end
-function image_loader.encode_png(tensor)
-   local im = gm.Image(tensor, "RGB", "DHW")
-   im:format("png")
-   return im:toBlob()
+function image_loader.encode_png(rgb, alpha)
+   if rgb:type() == "torch.ByteTensor" then
+      error("expect FloatTensor")
+   end
+   if alpha then
+      if not (alpha:size(2) == rgb:size(2) and  alpha:size(3) == rgb:size(3)) then
+	 alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
+      end
+      local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
+      rgba[1]:copy(rgb[1])
+      rgba[2]:copy(rgb[2])
+      rgba[3]:copy(rgb[3])
+      rgba[4]:copy(alpha)
+      local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
+      im:format("png")
+      return im:toBlob()
+   else
+      local im = gm.Image(rgb, "RGB", "DHW")
+      im:format("png")
+      return im:toBlob()
+   end
+end
+function image_loader.save_png(filename, rgb, alpha)
+   local blob, len = image_loader.encode_png(rgb, alpha)
+   local fp = io.open(filename, "wb")
+   fp:write(ffi.string(blob, len))
+   fp:close()
+   return true
 end
 function image_loader.decode_byte(blob)
    local load_image = function()
       local im = gm.Image()
+      local alpha = nil
+      
       im:fromBlob(blob, #blob)
       -- FIXME: How to detect that a image has an alpha channel?
       if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
-	 -- merge alpha channel
+	 -- split alpha channel
 	 im = im:toTensor('float', 'RGBA', 'DHW')
-	 local w2 = im[4]
-	 local w1 = im[4] * -1 + 1
+	 local sum_alpha = (im[4] - 1):sum()
+	 if sum_alpha > 0 or sum_alpha < 0 then
+	    alpha = im[4]:reshape(1, im:size(2), im:size(3))
+	 end
 	 local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
-	 -- apply the white background
-	 new_im[1]:copy(im[1]):cmul(w2):add(w1)
-	 new_im[2]:copy(im[2]):cmul(w2):add(w1)
-	 new_im[3]:copy(im[3]):cmul(w2):add(w1)
+	 new_im[1]:copy(im[1])
+	 new_im[2]:copy(im[2])
+	 new_im[3]:copy(im[3])
 	 im = new_im:mul(255):byte()
       else
 	 im = im:toTensor('byte', 'RGB', 'DHW')
       end
-      return im
+      return {im, alpha}
    end
    local state, ret = pcall(load_image)
    if state then
-      return ret
+      return ret[1], ret[2]
    else
       return nil
    end

+ 37 - 38
waifu2x.lua

@@ -6,13 +6,12 @@ require './lib/LeakyReLU'
 local iproc = require './lib/iproc'
 local reconstruct = require './lib/reconstruct'
 local image_loader = require './lib/image_loader'
-
 local BLOCK_OFFSET = 7
 
 torch.setdefaulttensortype('torch.FloatTensor')
 
 local function convert_image(opt)
-   local x = image_loader.load_float(opt.i)
+   local x, alpha = image_loader.load_float(opt.i)
    local new_x = nil
    local t = sys.clock()
    if opt.o == "(auto)" then
@@ -39,7 +38,7 @@ local function convert_image(opt)
    else
       error("undefined method:" .. opt.method)
    end
-   image.save(opt.o, new_x)
+   image_loader.save_png(opt.o, new_x, alpha)
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 local function convert_frames(opt)
@@ -59,41 +58,41 @@ local function convert_frames(opt)
    end
    fp:close()
    for i = 1, #lines do
-	if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
-	      local x = image_loader.load_float(lines[i])
-	      local new_x = nil
-	      if opt.m == "noise" and opt.noise_level == 1 then
-		 new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
-	      elseif opt.m == "noise" and opt.noise_level == 2 then
-		 new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
-	      elseif opt.m == "scale" then
-		 new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
-	      elseif opt.m == "noise_scale" and opt.noise_level == 1 then
-		 x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
-		 new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
-	      elseif opt.m == "noise_scale" and opt.noise_level == 2 then
-		 x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
-		 new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
-	      else
-		 error("undefined method:" .. opt.method)
-	      end
-	      local output = nil
-	      if opt.o == "(auto)" then
-		 local name = path.basename(lines[i])
-		 local e = path.extension(name)
-		 local base = name:sub(0, name:len() - e:len())
-		 output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
-	      else
-		 output = string.format(opt.o, i)
-	      end
-	      image.save(output, new_x)
-	      xlua.progress(i, #lines)
-	      if i % 10 == 0 then
-		 collectgarbage()
-	      end
-	else
-           xlua.progress(i, #lines)
-	end
+      if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
+	 local x, alpha = image_loader.load_float(lines[i])
+	 local new_x = nil
+	 if opt.m == "noise" and opt.noise_level == 1 then
+	    new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
+	 elseif opt.m == "noise" and opt.noise_level == 2 then
+	    new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
+	 elseif opt.m == "scale" then
+	    new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+	 elseif opt.m == "noise_scale" and opt.noise_level == 1 then
+	    x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
+	    new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+	 elseif opt.m == "noise_scale" and opt.noise_level == 2 then
+	    x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
+	    new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
+	 else
+	    error("undefined method:" .. opt.method)
+	 end
+	 local output = nil
+	 if opt.o == "(auto)" then
+	    local name = path.basename(lines[i])
+	    local e = path.extension(name)
+	    local base = name:sub(0, name:len() - e:len())
+	    output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
+	 else
+	    output = string.format(opt.o, i)
+	 end
+	 image_loader.save_png(output, new_x, alpha)
+	 xlua.progress(i, #lines)
+	 if i % 10 == 0 then
+	    collectgarbage()
+	 end
+      else
+	 xlua.progress(i, #lines)
+      end
    end
 end
 

+ 7 - 6
web.lua

@@ -1,3 +1,4 @@
+_G.TURBO_SSL = true
 local turbo = require 'turbo'
 local uuid = require 'uuid'
 local ffi = require 'ffi'
@@ -46,10 +47,10 @@ local function get_image(req)
    local url = req:get_argument("url", "")
    local blob = nil
    local img = nil
-   
+   local alpha = nil
    if file and file:len() > 0 then
       blob = file
-      img = image_loader.decode_float(blob)
+      img, alpha = image_loader.decode_float(blob)
    elseif url and url:len() > 0 then
       local res = coroutine.yield(
 	 turbo.async.HTTPClient({verify_ca=false},
@@ -63,11 +64,11 @@ local function get_image(req)
 	 end
 	 if content_type and content_type:find("image") then
 	    blob = res.body
-	    img = image_loader.decode_float(blob)
+	    img, alpha = image_loader.decode_float(blob)
 	 end
       end
    end
-   return img, blob
+   return img, blob, alpha
 end
 
 local function apply_denoise1(x)
@@ -103,7 +104,7 @@ function APIHandler:post()
       self:write("client disconnected")
       return
    end
-   local x, src = get_image(self)
+   local x, src, alpha = get_image(self)
    local scale = tonumber(self:get_argument("scale", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))
    if x and valid_size(x, scale) then
@@ -150,7 +151,7 @@ function APIHandler:post()
 	 end
       end
       local name = uuid() .. ".png"
-      local blob, len = image_loader.encode_png(x)
+      local blob, len = image_loader.encode_png(x, alpha)
       
       self:set_header("Content-Disposition", string.format('filename="%s"', name))
       self:set_header("Content-Type", "image/png")