Kaynağa Gözat

Add support for tta_level=1; Add support for TTA to web.lua

nagadomi 9 yıl önce
ebeveyn
işleme
25e293202a
2 değiştirilmiş dosya ile 41 ekleme ve 21 silme
  1. 4 1
      lib/reconstruct.lua
  2. 37 20
      web.lua

+ 4 - 1
lib/reconstruct.lua

@@ -269,7 +269,10 @@ local augmented_patterns = {
    }
 }
 local function get_augmented_patterns(n)
-   if n == 2 then
+   if n == 1 then
+      -- no tta
+      return {augmented_patterns[1]}
+   elseif n == 2 then
       return {augmented_patterns[1], augmented_patterns[5]}
    elseif n == 4 then
       return {augmented_patterns[1], augmented_patterns[5],

+ 37 - 20
web.lua

@@ -63,6 +63,7 @@ local CURL_OPTIONS = {
    max_redirects = 2
 }
 local CURL_MAX_SIZE = 3 * 1024 * 1024
+local TTA_SUPPORT = false
 
 local function valid_size(x, scale)
    if scale == 0 then
@@ -151,8 +152,8 @@ local function convert(x, meta, options)
 	    x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model))
 	 end
 	 if options.method == "scale" then
-	    x = reconstruct.scale(art_scale2_model, 2.0, x,
-				  opt.crop_size, opt.batch_size)
+	    x = reconstruct.scale_tta(art_scale2_model, options.tta_level, 2.0, x,
+				      opt.crop_size, opt.batch_size)
 	    if alpha then
 	       if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
 		  alpha = reconstruct.scale(art_scale2_model, 2.0, alpha,
@@ -162,13 +163,16 @@ local function convert(x, meta, options)
 	    end
 	    cleanup_model(art_scale2_model)
 	 elseif options.method == "noise1" then
-	    x = reconstruct.image(art_noise1_model, x, opt.crop_size, opt.batch_size)
+	    x = reconstruct.image_tta(art_noise1_model, options.tta_level,
+				      x, opt.crop_size, opt.batch_size)
 	    cleanup_model(art_noise1_model)
 	 elseif options.method == "noise2" then
-	    x = reconstruct.image(art_noise2_model, x, opt.crop_size, opt.batch_size)
+	    x = reconstruct.image_tta(art_noise2_model, options.tta_level,
+				      x, opt.crop_size, opt.batch_size)
 	    cleanup_model(art_noise2_model)
 	 elseif options.method == "noise3" then
-	    x = reconstruct.image(art_noise3_model, x, opt.crop_size, opt.batch_size)
+	    x = reconstruct.image_tta(art_noise3_model, options.tta_level,
+				      x, opt.crop_size, opt.batch_size)
 	    cleanup_model(art_noise3_model)
 	 end
       else -- photo
@@ -176,7 +180,7 @@ local function convert(x, meta, options)
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
 	 end
 	 if options.method == "scale" then
-	    x = reconstruct.scale(photo_scale2_model, 2.0, x,
+	    x = reconstruct.scale_tta(photo_scale2_model, options.tta_level, 2.0, x,
 				  opt.crop_size, opt.batch_size)
 	    if alpha then
 	       if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
@@ -187,13 +191,16 @@ local function convert(x, meta, options)
 	    end
 	    cleanup_model(photo_scale2_model)
 	 elseif options.method == "noise1" then
-	    x = reconstruct.image(photo_noise1_model, x, opt.crop_size, opt.batch_size)
+	    x = reconstruct.image_tta(photo_noise1_model, options.tta_level,
+				      x, opt.crop_size, opt.batch_size)
 	    cleanup_model(photo_noise1_model)
 	 elseif options.method == "noise2" then
-	    x = reconstruct.image(photo_noise2_model, x, opt.crop_size, opt.batch_size)
+	    x = reconstruct.image_tta(photo_noise2_model, options.tta_level,
+				      x, opt.crop_size, opt.batch_size)
 	    cleanup_model(photo_noise2_model)
 	 elseif options.method == "noise3" then
-	    x = reconstruct.image(photo_noise3_model, x, opt.crop_size, opt.batch_size)
+	    x = reconstruct.image_tta(photo_noise3_model, options.tta_level,
+				      x, opt.crop_size, opt.batch_size)
 	    cleanup_model(photo_noise3_model)
 	 end
       end
@@ -230,9 +237,18 @@ function APIHandler:post()
    local x, meta, filename = get_image(self)
    local scale = tonumber(self:get_argument("scale", "0"))
    local noise = tonumber(self:get_argument("noise", "0"))
+   local tta_level = tonumber(self:get_argument("noise", "1"))
    local style = self:get_argument("style", "art")
    local download = (self:get_argument("download", "")):len()
 
+   if not TTA_SUPPORT then
+      tta_level = 1 -- disable TTA mode
+   else
+      if not (tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then
+	 tta_level = 1
+      end
+   end
+
    if style ~= "art" then
       style = "photo" -- style must be art or photo
    end
@@ -246,35 +262,36 @@ function APIHandler:post()
 	    border = true
 	 end
 	 if noise == 1 then
-	    prefix = style .. "_noise1_"
-	    x = convert(x, meta, {method = "noise1", style = style,
+	    prefix = style .. "_noise1_tta_" .. tta_level .. "_"
+	    x = convert(x, meta, {method = "noise1", style = style, tta_level = tta_level,
 				  prefix = prefix .. hash,
 				  alpha_prefix = alpha_prefix, border = border})
 	    border = false
 	 elseif noise == 2 then
-	    prefix = style .. "_noise2_"
-	    x = convert(x, meta, {method = "noise2", style = style,
+	    prefix = style .. "_noise2_tta_" .. tta_level .. "_"
+	    x = convert(x, meta, {method = "noise2", style = style, tta_level = tta_level,
 				  prefix = prefix .. hash, 
 				  alpha_prefix = alpha_prefix, border = border})
 	    border = false
 	 elseif noise == 3 then
-	    prefix = style .. "_noise3_"
-	    x = convert(x, meta, {method = "noise3", style = style,
+	    prefix = style .. "_noise3_tta_" .. tta_level .. "_"
+	    x = convert(x, meta, {method = "noise3", style = style, tta_level = tta_level,
 				  prefix = prefix .. hash, 
 				  alpha_prefix = alpha_prefix, border = border})
 	    border = false
 	 end
 	 if scale == 1 or scale == 2 then
 	    if noise == 1 then
-	       prefix = style .. "_noise1_scale_"
+	       prefix = style .. "_noise1_scale_tta_"  .. tta_level .. "_"
 	    elseif noise == 2 then
-	       prefix = style .. "_noise2_scale_"
+	       prefix = style .. "_noise2_scale_tta_"  .. tta_level .. "_"
 	    elseif noise == 3 then
-	       prefix = style .. "_noise3_scale_"
+	       prefix = style .. "_noise3_scale_tta_" .. tta_level .. "_"
 	    else
-	       prefix = style .. "_scale_"
+	       prefix = style .. "_scale_tta_"  .. tta_level .. "_"
 	    end
-	    x, meta = convert(x, meta, {method = "scale", style = style, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
+	    x, meta = convert(x, meta, {method = "scale", style = style, tta_level = tta_level,
+					prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border})
 	    if scale == 1 then
 	       x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc")
 	    end