瀏覽代碼

Add -tta option

The TTA mode:
- 8x slower than normal mode
- improves PSNR +0.1
nagadomi 9 年之前
父節點
當前提交
b335f3a9ad
共有 2 個文件被更改,包括 110 次插入15 次删除
  1. 59 0
      lib/reconstruct.lua
  2. 51 15
      waifu2x.lua

+ 59 - 0
lib/reconstruct.lua

@@ -206,5 +206,64 @@ function reconstruct.scale(model, scale, x, block_size)
 				 reconstruct.offset_size(model), block_size)
    end
 end
+local function tta(f, model, x, block_size)
+   local average = nil
+   local offset = reconstruct.offset_size(model)
+   for i = 1, 4 do 
+      local flip_f, iflip_f
+      if i == 1 then
+	 flip_f = function (a) return a end
+	 iflip_f = function (a) return a end
+      elseif i == 2 then
+	 flip_f = image.vflip
+	 iflip_f = image.vflip
+      elseif i == 3 then
+	 flip_f = image.hflip
+	 iflip_f = image.hflip
+      elseif i == 4 then
+	 flip_f = function (a) return image.hflip(image.vflip(a)) end
+	 iflip_f = function (a) return image.vflip(image.hflip(a)) end
+      end
+      for j = 1, 2 do
+	 local tr_f, itr_f
+	 if j == 1 then
+	    tr_f = function (a) return a end
+	    itr_f = function (a) return a end
+	 elseif j == 2 then
+	    tr_f = function(a) return a:transpose(2, 3):contiguous() end
+	    itr_f = function(a) return a:transpose(2, 3):contiguous() end
+	 end
+	 local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)),
+				     offset, block_size)))
+	 if not average then
+	    average = out
+	 else
+	    average:add(out)
+	 end
+      end
+   end
+   return average:div(8.0)
+end
+function reconstruct.image_tta(model, x, block_size)
+   if reconstruct.is_rgb(model) then
+      return tta(reconstruct.image_rgb, model, x, block_size)
+   else
+      return tta(reconstruct.image_y, model, x, block_size)
+   end
+end
+function reconstruct.scale_tta(model, scale, x, block_size)
+   if reconstruct.is_rgb(model) then
+      local f = function (model, x, offset, block_size)
+	 return reconstruct.scale_rgb(model, scale, x, offset, block_size)
+      end
+      return tta(f, model, x, block_size)
+		 
+   else
+      local f = function (model, x, offset, block_size)
+	 return reconstruct.scale_y(model, scale, x, offset, block_size)
+      end
+      return tta(f, model, x, block_size)
+   end
+end
 
 return reconstruct

+ 51 - 15
waifu2x.lua

@@ -13,6 +13,14 @@ local function convert_image(opt)
    local x, alpha = image_loader.load_float(opt.i)
    local new_x = nil
    local t = sys.clock()
+   local scale_f, image_f
+   if opt.tta == 1 then
+      scale_f = reconstruct.scale_tta
+      image_f = reconstruct.image_tta
+   else
+      scale_f = reconstruct.scale
+      image_f = reconstruct.image
+   end
    if opt.o == "(auto)" then
       local name = path.basename(opt.i)
       local e = path.extension(name)
@@ -25,14 +33,14 @@ local function convert_image(opt)
       if not model then
 	 error("Load Error: " .. model_path)
       end
-      new_x = reconstruct.image(model, x, opt.crop_size)
+      new_x = image_f(model, x, opt.crop_size)
    elseif opt.m == "scale" then
       local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
       local model = torch.load(model_path, "ascii")
       if not model then
 	 error("Load Error: " .. model_path)
       end
-      new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
+      new_x = scale_f(model, opt.scale, x, opt.crop_size)
    elseif opt.m == "noise_scale" then
       local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
       local noise_model = torch.load(noise_model_path, "ascii")
@@ -45,8 +53,8 @@ local function convert_image(opt)
       if not scale_model then
 	 error("Load Error: " .. scale_model_path)
       end
-      x = reconstruct.image(noise_model, x)
-      new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
+      x = image_f(noise_model, x)
+      new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
    else
       error("undefined method:" .. opt.method)
    end
@@ -54,25 +62,52 @@ local function convert_image(opt)
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 local function convert_frames(opt)
-   local noise1_model, noise2_model, scale_model
+   local model_path, noise1_model, noise2_model, scale_model
+   local scale_f, image_f
+   if opt.tta == 1 then
+      scale_f = reconstruct.scale_tta
+      image_f = reconstruct.image_tta
+   else
+      scale_f = reconstruct.scale
+      image_f = reconstruct.image
+   end
    if opt.m == "scale" then
-      local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+      model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
       scale_model = torch.load(model_path, "ascii")
       if not scale_model then
 	 error("Load Error: " .. model_path)
       end
    elseif opt.m == "noise" and opt.noise_level == 1 then
-      local model_path = path.join(opt.model_dir, "noise1_model.t7")
+      model_path = path.join(opt.model_dir, "noise1_model.t7")
       noise1_model = torch.load(model_path, "ascii")
       if not noise1_model then
 	 error("Load Error: " .. model_path)
       end
    elseif opt.m == "noise" and opt.noise_level == 2 then
-      local model_path = path.join(opt.model_dir, "noise2_model.t7")
+      model_path = path.join(opt.model_dir, "noise2_model.t7")
       noise2_model = torch.load(model_path, "ascii")
       if not noise2_model then
 	 error("Load Error: " .. model_path)
       end
+   elseif opt.m == "noise_scale" then
+      model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+      scale_model = torch.load(model_path, "ascii")
+      if not scale_model then
+	 error("Load Error: " .. model_path)
+      end
+      if opt.noise_level == 1 then
+	 model_path = path.join(opt.model_dir, "noise1_model.t7")
+	 noise1_model = torch.load(model_path, "ascii")
+	 if not noise1_model then
+	    error("Load Error: " .. model_path)
+	 end
+      elseif opt.noise_level == 2 then
+	 model_path = path.join(opt.model_dir, "noise2_model.t7")
+	 noise2_model = torch.load(model_path, "ascii")
+	 if not noise2_model then
+	    error("Load Error: " .. model_path)
+	 end
+      end
    end
    local fp = io.open(opt.l)
    if not fp then
@@ -89,17 +124,17 @@ local function convert_frames(opt)
 	 local x, alpha = image_loader.load_float(lines[i])
 	 local new_x = nil
 	 if opt.m == "noise" and opt.noise_level == 1 then
-	    new_x = reconstruct.image(noise1_model, x, opt.crop_size)
+	    new_x = image_f(noise1_model, x, opt.crop_size)
 	 elseif opt.m == "noise" and opt.noise_level == 2 then
-	    new_x = reconstruct.image(noise2_model, x)
+	    new_x = image_func(noise2_model, x)
 	 elseif opt.m == "scale" then
-	    new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
+	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	 elseif opt.m == "noise_scale" and opt.noise_level == 1 then
-	    x = reconstruct.image(noise1_model, x)
-	    new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
+	    x = image_f(noise1_model, x)
+	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	 elseif opt.m == "noise_scale" and opt.noise_level == 2 then
-	    x = reconstruct.image(noise2_model, x)
-	    new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
+	    x = image_f(noise2_model, x, opt.crop_size)
+	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	 else
 	    error("undefined method:" .. opt.method)
 	 end
@@ -139,6 +174,7 @@ local function waifu2x()
    cmd:option("-crop_size", 128, 'patch size per process')
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")
+   cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
    
    local opt = cmd:parse(arg)
    if opt.thread > 0 then