소스 검색

Merge branch 'master' of github.com:nagadomi/waifu2x into dev

nagadomi 9 년 전
부모
커밋
0b130b797c
1개의 변경된 파일44개의 추가작업 그리고 13개의 파일을 삭제
  1. 44 13
      web.lua

+ 44 - 13
web.lua

@@ -92,14 +92,25 @@ local function cache_url(url)
    return nil, nil, nil
 end
 local function get_image(req)
-   local file = req:get_argument("file", "")
+   local file_info = req:get_arguments("file")
    local url = req:get_argument("url", "")
+   local file = nil
+   local filename = nil
+   if file_info and #file_info == 1 then
+      file = file_info[1][1]
+      local disp = file_info[1]["content-disposition"]
+      if disp and disp["filename"] then
+	 filename = path.basename(disp["filename"])
+      end
+   end
    if file and file:len() > 0 then
-      return image_loader.decode_float(file)
+      local x, alpha, blob = image_loader.decode_float(file)
+      return x, alpha, blob, filename
    elseif url and url:len() > 0 then
-      return cache_url(url)
+      local x, alpha, blob = cache_url(url)
+      return x, alpha, blob, filename
    end
-   return nil, nil, nil
+   return nil, nil, nil, nil
 end
 local function cleanup_model(model)
    if CLEANUP_MODEL then
@@ -176,6 +187,15 @@ local function client_disconnected(handler)
 		 handler.request.connection.stream and
 		 (not handler.request.connection.stream:closed()))
 end
+local function make_output_filename(filename, mode)
+   local e = path.extension(filename)
+   local base = filename:sub(0, filename:len() - e:len())
+   if mode then
+      return base .. "_waifu2x_" .. mode .. ".png"
+   else
+      return base .. ".png"
+   end
+end
 
 local APIHandler = class("APIHandler", turbo.web.RequestHandler)
 function APIHandler:post()
@@ -184,7 +204,7 @@ function APIHandler:post()
       self:write("client disconnected")
       return
    end
-   local x, alpha, blob = get_image(self)
+   local x, alpha, blob, filename = get_image(self)
    local scale = tonumber(self:get_argument("scale", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))
    local style = self:get_argument("style", "art")
@@ -194,6 +214,7 @@ function APIHandler:post()
       style = "photo" -- style must be art or photo
    end
    if x and valid_size(x, scale) then
+      local prefix = nil
       if (noise ~= 0 or scale ~= 0) then
 	 local hash = md5.sumhexa(blob)
 	 local alpha_prefix = style .. "_" .. hash .. "_alpha"
@@ -202,32 +223,42 @@ function APIHandler:post()
 	    border = true
 	 end
 	 if noise == 1 then
+	    prefix = style .. "_noise1_"
 	    x = convert(x, alpha, {method = "noise1", style = style,
-				   prefix = style .. "_noise1_" .. hash,
+				   prefix = prefix .. hash,
 				   alpha_prefix = alpha_prefix, border = border})
 	    border = false
 	 elseif noise == 2 then
+	    prefix = style .. "_noise1_"
 	    x = convert(x, alpha, {method = "noise2", style = style,
-				   prefix = style .. "_noise2_" .. hash, 
+				   prefix = prefix .. hash, 
 				   alpha_prefix = alpha_prefix, border = border})
 	    border = false
 	 end
 	 if scale == 1 or scale == 2 then
-	    local prefix
 	    if noise == 1 then
-	       prefix = style .. "_noise1_scale_" .. hash
+	       prefix = style .. "_noise1_scale_"
 	    elseif noise == 2 then
-	       prefix = style .. "_noise2_scale_" .. hash
+	       prefix = style .. "_noise2_scale_"
 	    else
-	       prefix = style .. "_scale_" .. hash
+	       prefix = style .. "_scale_"
 	    end
-	    x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix, alpha_prefix = alpha_prefix, border = border})
+	    x, alpha = convert(x, alpha, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
 	    if scale == 1 then
 	       x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
 	    end
 	 end
       end
-      local name = uuid() .. ".png"
+      local name = nil
+      if filename then 
+	 if prefix then
+	    name = make_output_filename(filename, prefix:sub(0, prefix:len()-1))
+	 else
+	    name = make_output_filename(filename, nil)
+	 end
+      else
+	 name = uuid() .. ".png"
+      end
       local blob = image_loader.encode_png(alpha_util.composite(x, alpha))
 
       self:set_header("Content-Length", string.format("%d", #blob))