Browse Source

Add support for TTA level

nagadomi 9 năm trước cách đây
mục cha
commit
0b949c05a7
2 tập tin đã thay đổi với 84 bổ sung44 xóa
  1. 67 40
      lib/reconstruct.lua
  2. 17 4
      waifu2x.lua

+ 67 - 40
lib/reconstruct.lua

@@ -209,63 +209,90 @@ function reconstruct.scale(model, scale, x, block_size, upsampling_filter)
    end
    end
    return x
    return x
 end
 end
-local function tta(f, model, x, block_size)
+local function tr_f(a)
+   return a:transpose(2, 3):contiguous() 
+end
+local function itr_f(a)
+   return a:transpose(2, 3):contiguous()
+end
+local augmented_patterns = {
+   {
+      forward = function (a) return a end,
+      backward = function (a) return a end
+   },
+   {
+      forward = function (a) return image.hflip(a) end,
+      backward = function (a) return image.hflip(a) end
+   },
+   {
+      forward = function (a) return image.vflip(a) end,
+      backward = function (a) return image.vflip(a) end
+   },
+   {
+      forward = function (a) return image.hflip(image.vflip(a)) end,
+      backward = function (a) return image.vflip(image.hflip(a)) end
+   },
+   {
+      forward = function (a) return tr_f(a) end,
+      backward = function (a) return itr_f(a) end
+   },
+   {
+      forward = function (a) return image.hflip(tr_f(a)) end,
+      backward = function (a) return itr_f(image.hflip(a)) end
+   },
+   {
+      forward = function (a) return image.vflip(tr_f(a)) end,
+      backward = function (a) return itr_f(image.vflip(a)) end
+   },
+   {
+      forward = function (a) return image.hflip(image.vflip(tr_f(a))) end,
+      backward = function (a) return itr_f(image.vflip(image.hflip(a))) end
+   }
+}
+local function get_augmented_patterns(n)
+   if n == 2 then
+      return {augmented_patterns[1], augmented_patterns[5]}
+   elseif n == 4 then
+      return {augmented_patterns[1], augmented_patterns[5],
+	      augmented_patterns[2], augmented_patterns[7]}
+   elseif n == 8 then
+      return augmented_patterns
+   else
+      error("unsupported TTA level: " .. n)
+   end
+end
+local function tta(f, n, model, x, block_size)
    local average = nil
    local average = nil
    local offset = reconstruct.offset_size(model)
    local offset = reconstruct.offset_size(model)
-   for i = 1, 4 do 
-      local flip_f, iflip_f
-      if i == 1 then
-	 flip_f = function (a) return a end
-	 iflip_f = function (a) return a end
-      elseif i == 2 then
-	 flip_f = image.vflip
-	 iflip_f = image.vflip
-      elseif i == 3 then
-	 flip_f = image.hflip
-	 iflip_f = image.hflip
-      elseif i == 4 then
-	 flip_f = function (a) return image.hflip(image.vflip(a)) end
-	 iflip_f = function (a) return image.vflip(image.hflip(a)) end
-      end
-      for j = 1, 2 do
-	 local tr_f, itr_f
-	 if j == 1 then
-	    tr_f = function (a) return a end
-	    itr_f = function (a) return a end
-	 elseif j == 2 then
-	    tr_f = function(a) return a:transpose(2, 3):contiguous() end
-	    itr_f = function(a) return a:transpose(2, 3):contiguous() end
-	 end
-	 local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)),
-				     offset, block_size)))
-	 if not average then
-	    average = out
-	 else
-	    average:add(out)
-	 end
+   local augments = get_augmented_patterns(n)
+   for i = 1, #augments do 
+      local out = augments[i].backward(f(model, augments[i].forward(x), offset, block_size))
+      if not average then
+	 average = out
+      else
+	 average:add(out)
       end
       end
    end
    end
-   return average:div(8.0)
+   return average:div(#augments)
 end
 end
-function reconstruct.image_tta(model, x, block_size)
+function reconstruct.image_tta(model, n, x, block_size)
    if reconstruct.is_rgb(model) then
    if reconstruct.is_rgb(model) then
-      return tta(reconstruct.image_rgb, model, x, block_size)
+      return tta(reconstruct.image_rgb, n, model, x, block_size)
    else
    else
-      return tta(reconstruct.image_y, model, x, block_size)
+      return tta(reconstruct.image_y, n, model, x, block_size)
    end
    end
 end
 end
-function reconstruct.scale_tta(model, scale, x, block_size, upsampling_filter)
+function reconstruct.scale_tta(model, n, scale, x, block_size, upsampling_filter)
    if reconstruct.is_rgb(model) then
    if reconstruct.is_rgb(model) then
       local f = function (model, x, offset, block_size)
       local f = function (model, x, offset, block_size)
 	 return reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
 	 return reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
       end
       end
-      return tta(f, model, x, block_size)
-		 
+      return tta(f, n, model, x, block_size)
    else
    else
       local f = function (model, x, offset, block_size)
       local f = function (model, x, offset, block_size)
 	 return reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
 	 return reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
       end
       end
-      return tta(f, model, x, block_size)
+      return tta(f, n, model, x, block_size)
    end
    end
 end
 end
 
 

+ 17 - 4
waifu2x.lua

@@ -44,8 +44,14 @@ local function convert_image(opt)
    local scale_f, image_f
    local scale_f, image_f
 
 
    if opt.tta == 1 then
    if opt.tta == 1 then
-      scale_f = reconstruct.scale_tta
-      image_f = reconstruct.image_tta
+      scale_f = function(model, scale, x, block_size, upsampling_filter)
+	 return reconstruct.scale_tta(model, opt.tta_level,
+				      scale, x, block_size, upsampling_filter)
+      end
+      image_f = function(model, x, block_size)
+	 return reconstruct.image_tta(model, opt.tta_level,
+				      x, block_size)
+      end
    else
    else
       scale_f = reconstruct.scale
       scale_f = reconstruct.scale
       image_f = reconstruct.image
       image_f = reconstruct.image
@@ -119,8 +125,14 @@ local function convert_frames(opt)
    local noise_model = {}
    local noise_model = {}
    local scale_f, image_f
    local scale_f, image_f
    if opt.tta == 1 then
    if opt.tta == 1 then
-      scale_f = reconstruct.scale_tta
-      image_f = reconstruct.image_tta
+      scale_f = function(model, scale, x, block_size, upsampling_filter)
+	 return reconstruct.scale_tta(model, opt.tta_level,
+				      scale, x, block_size, upsampling_filter)
+      end
+      image_f = function(model, x, block_size)
+	 return reconstruct.image_tta(model, opt.tta_level,
+				      x, block_size)
+      end
    else
    else
       scale_f = reconstruct.scale
       scale_f = reconstruct.scale
       image_f = reconstruct.image
       image_f = reconstruct.image
@@ -226,6 +238,7 @@ local function waifu2x()
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")
    cmd:option("-thread", -1, "number of CPU threads")
    cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
    cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
+   cmd:option("-tta_level", 8, 'TTA level (2|4|8)')
    cmd:option("-upsampling_filter", "Box", 'upsampling filter (for dev)')
    cmd:option("-upsampling_filter", "Box", 'upsampling filter (for dev)')
    
    
    local opt = cmd:parse(arg)
    local opt = cmd:parse(arg)