Browse Source

Add enable_tta option in web.lua

nagadomi 6 years ago
parent
commit
208043fd89
1 changed files with 10 additions and 5 deletions
  1. 10 5
      web.lua

+ 10 - 5
web.lua

@@ -26,6 +26,7 @@ cmd:text("waifu2x-api")
 cmd:text("Options:")
 cmd:option("-port", 8812, 'listen port')
 cmd:option("-gpu", 1, 'Device ID')
+cmd:option("-enable_tta", 0, 'enable TTA query(0|1)')
 cmd:option("-crop_size", 128, 'patch size per process')
 cmd:option("-batch_size", 1, 'batch size')
 cmd:option("-thread", -1, 'number of CPU threads')
@@ -48,6 +49,7 @@ if cudnn then
    cudnn.benchmark = true
 end
 opt.force_cudnn = opt.force_cudnn == 1
+opt.enable_tta = opt.enable_tta == 1
 local ART_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "art")
 local PHOTO_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "photo")
 local art_model = {
@@ -313,11 +315,14 @@ function APIHandler:post()
       self:write("client disconnected")
       return
    end
-
-   if tta_level == 0 then
-      tta_level = auto_tta_level(x, scale)
-   end
-   if not (tta_level == 0 or tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then
+   if opt.enable_tta then
+      if tta_level == 0 then
+	 tta_level = auto_tta_level(x, scale)
+      end
+      if not (tta_level == 0 or tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then
+	 tta_level = 1
+      end
+   else
       tta_level = 1
    end
    if style ~= "art" then