Browse Source

Change cache storage from HDD to RAM

nagadomi 8 năm trước cách đây
mục cha
commit
2280f388ff
1 tập tin đã thay đổi với 66 bổ sung17 xóa
  1. 66 17
      web.lua

+ 66 - 17
web.lua

@@ -12,6 +12,7 @@ local iproc = require 'iproc'
 local reconstruct = require 'reconstruct'
 local image_loader = require 'image_loader'
 local alpha_util = require 'alpha_util'
+local compression = require 'compression'
 local gm = require 'graphicsmagick'
 
 -- Note:  turbo and xlua has different implementation of string:split().
@@ -34,6 +35,7 @@ cmd:option("-curl_request_timeout", 60, "request_timeout for curl")
 cmd:option("-curl_connect_timeout", 60, "connect_timeout for curl")
 cmd:option("-curl_max_redirects", 2, "max_redirects for curl")
 cmd:option("-max_body_size", 5 * 1024 * 1024, "maximum allowed size for uploaded files")
+cmd:option("-cache_max", 200, "number of cached images on RAM")
 
 local opt = cmd:parse(arg)
 cutorch.setDevice(opt.gpu)
@@ -75,6 +77,7 @@ local CLEANUP_MODEL = false -- if you are using the low memory GPU, you could us
 local CACHE_DIR = path.join(ROOT, "cache")
 local MAX_NOISE_IMAGE = opt.max_pixels
 local MAX_SCALE_IMAGE = (math.sqrt(opt.max_pixels) / 2)^2
+local PNG_DEPTH = 8
 local CURL_OPTIONS = {
    request_timeout = opt.curl_request_timeout,
    connect_timeout = opt.curl_connect_timeout,
@@ -167,26 +170,73 @@ local function cleanup_model(model)
       model:clearState() -- release GPU memory
    end
 end
+
+-- cache
+local g_cache = {}
+local function cache_count()
+   local count = 0
+   for _ in pairs(g_cache) do
+      count = count + 1
+   end
+   return count
+end
+local function cache_remove_old()
+   local old_time = nil
+   local old_key = nil
+   for k, v in pairs(g_cache) do
+      if old_time == nil or old_time > v.updated_at then
+	 old_key = k
+	 old_time = v.updated_at
+      end
+   end
+   if old_key then
+      g_cache[old_key] = nil
+   end
+end
+local function cache_compress(raw_image)
+   if raw_image then
+      compressed_image = compression.compress(iproc.float2byte(raw_image))
+      return compressed_image
+   else
+      return nil
+   end
+end
+local function cache_decompress(compressed_image)
+   if compressed_image then
+      local raw_image = compression.decompress(compressed_image)
+      return iproc.byte2float(raw_image)
+   else
+      return nil
+   end
+end
+local function cache_get(filename)
+   local cache = g_cache[filename]
+   if cache then
+      return {image = cache_decompress(cache.image),
+	      alpha = cache_decompress(cache.alpha)}
+   else
+      return nil
+   end
+end
+local function cache_put(filename, image, alpha)
+   g_cache[filename] = {image = cache_compress(image),
+			alpha = cache_compress(alpha),
+			updated_at = os.time()};
+   local count = cache_count(g_cache)
+   if count > opt.cache_max then
+      cache_remove_old()
+   end
+end
 local function convert(x, meta, options)
    local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
-   local alpha_cache_file = path.join(CACHE_DIR, options.alpha_prefix .. ".png")
    local alpha = meta.alpha
    local alpha_orig = alpha
+   local cache = cache_get(cache_file)
 
-   if path.exists(alpha_cache_file) then
-      alpha = image_loader.load_float(alpha_cache_file)
-      if alpha:dim() == 2 then
-	 alpha = alpha:reshape(1, alpha:size(1), alpha:size(2))
-      end
-      if alpha:size(1) == 3 then
-	 alpha = image.rgb2y(alpha)
-      end
-   end
-   if path.exists(cache_file) then
-      x = image_loader.load_float(cache_file)
+   if cache then
       meta = tablex.copy(meta)
-      meta.alpha = alpha
-      return x, meta
+      meta.alpha = cache.alpha
+      return cache.image, meta
    else
       local model = nil
       if options.style == "art" then
@@ -209,7 +259,6 @@ local function convert(x, meta, options)
 	    if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
 	       alpha = reconstruct.scale(model.scale, 2.0, alpha,
 					 opt.crop_size, opt.batch_size)
-	       image_loader.save_png(alpha_cache_file, alpha)
 	       cleanup_model(model.scale)
 	    end
 	 end
@@ -223,7 +272,7 @@ local function convert(x, meta, options)
 				   x, opt.crop_size, opt.batch_size)
 	 cleanup_model(model[options.method])
       end
-      image_loader.save_png(cache_file, x)
+      cache_put(cache_file, x, alpha)
       meta = tablex.copy(meta)
       meta.alpha = alpha
       return x, meta
@@ -321,7 +370,7 @@ function APIHandler:post()
 	 name = uuid() .. ".png"
       end
       local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha),
-					   tablex.update({depth = 8, inplace = true}, meta))
+					   tablex.update({depth = PNG_DEPTH, inplace = true}, meta))
 
       self:set_header("Content-Length", string.format("%d", #blob))
       if download > 0 then