Browse Source

Add support for url cache in web.lua

nagadomi 9 years ago
parent
commit
539941c234
3 changed files with 43 additions and 30 deletions
  1. 1 0
      .gitignore
  2. 4 4
      lib/image_loader.lua
  3. 38 26
      web.lua

+ 1 - 0
.gitignore

@@ -1,6 +1,7 @@
 *~
 work/
 cache/*.png
+cache/url_*
 data/
 !data/.gitkeep
 

+ 4 - 4
lib/image_loader.lua

@@ -9,7 +9,7 @@ function image_loader.decode_float(blob)
    if im then
       im = im:float():div(255)
    end
-   return im, alpha
+   return im, alpha, blob
 end
 function image_loader.encode_png(rgb, alpha)
    if rgb:type() == "torch.ByteTensor" then
@@ -74,14 +74,14 @@ function image_loader.decode_byte(blob)
       else
 	 im = im:toTensor('byte', 'RGB', 'DHW')
       end
-      return {im, alpha}
+      return {im, alpha, blob}
    end
    load_image()
    local state, ret = pcall(load_image)
    if state then
-      return ret[1], ret[2]
+      return ret[1], ret[2], ret[3]
    else
-      return nil
+      return nil, nil, nil
    end
 end
 function image_loader.load_float(file)

+ 38 - 26
web.lua

@@ -55,20 +55,25 @@ local function valid_size(x, scale)
    end
 end
 
-local function get_image(req)
-   local file = req:get_argument("file", "")
-   local url = req:get_argument("url", "")
-   local blob = nil
-   local img = nil
-   local alpha = nil
-   if file and file:len() > 0 then
-      blob = file
-      img, alpha = image_loader.decode_float(blob)
-   elseif url and url:len() > 0 then
+local function apply_denoise1(x)
+   return reconstruct.image(noise1_model, x)
+end
+local function apply_denoise2(x)
+   return reconstruct.image(noise2_model, x)
+end
+local function apply_scale2x(x)
+   return reconstruct.scale(scale20_model, 2.0, x)
+end
+local function cache_url(url)
+   local hash = md5.sumhexa(url)
+   local cache_file = path.join(CACHE_DIR, "url_" .. hash)
+   if path.exists(cache_file) then
+      return image_loader.load_float(cache_file)
+   else
       local res = coroutine.yield(
 	 turbo.async.HTTPClient({verify_ca=false},
-				nil,
-				CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
+	    nil,
+	    CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
       )
       if res.code == 200 then
 	 local content_type = res.headers:get("Content-Type", true)
@@ -76,22 +81,15 @@ local function get_image(req)
 	    content_type = content_type[1]
 	 end
 	 if content_type and content_type:find("image") then
-	    blob = res.body
-	    img, alpha = image_loader.decode_float(blob)
+	    local fp = io.open(cache_file, "wb")
+	    local blob = res.body
+	    fp:write(blob)
+	    fp:close()
+	    return image_loader.decode_float(blob)
 	 end
       end
    end
-   return img, blob, alpha
-end
-
-local function apply_denoise1(x)
-   return reconstruct.image(noise1_model, x)
-end
-local function apply_denoise2(x)
-   return reconstruct.image(noise2_model, x)
-end
-local function apply_scale2x(x)
-   return reconstruct.scale(scale20_model, 2.0, x)
+   return nil, nil, nil
 end
 local function cache_do(cache, x, func)
    if path.exists(cache) then
@@ -102,6 +100,20 @@ local function cache_do(cache, x, func)
       return x
    end
 end
+local function get_image(req)
+   local file = req:get_argument("file", "")
+   local url = req:get_argument("url", "")
+   local blob = nil
+   local img = nil
+   local alpha = nil
+   if file and file:len() > 0 then
+      blob = file
+      return image_loader.decode_float(blob)
+   elseif url and url:len() > 0 then
+      return cache_url(url)
+   end
+   return nil, nil, nil
+end
 
 local function client_disconnected(handler)
    return not(handler.request and
@@ -117,7 +129,7 @@ function APIHandler:post()
       self:write("client disconnected")
       return
    end
-   local x, src, alpha = get_image(self)
+   local x, alpha, src = get_image(self)
    local scale = tonumber(self:get_argument("scale", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))
    if x and valid_size(x, scale) then