Browse Source

Add upsampling_filter option

nagadomi 9 years ago
parent
commit
30fe5db735
6 changed files with 36 additions and 21 deletions
  1. 2 2
      lib/pairwise_transform.lua
  2. 16 10
      lib/reconstruct.lua
  3. 1 0
      lib/settings.lua
  4. 2 1
      train.lua
  5. 5 4
      waifu2x.lua
  6. 10 4
      web.lua

+ 2 - 2
lib/pairwise_transform.lua

@@ -92,11 +92,11 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
    if options.gamma_correction then
       x = iproc.scale(iproc.scale_with_gamma22(y, y:size(3) * down_scale,
 					       y:size(2) * down_scale, downsampling_filter),
-		      y:size(3), y:size(2))
+		      y:size(3), y:size(2), options.upsampling_filter)
    else
       x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
 				  y:size(2) * down_scale, downsampling_filter),
-		      y:size(3), y:size(2))
+		      y:size(3), y:size(2), options.upsampling_filter)
    end
    x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
 		  x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)

+ 16 - 10
lib/reconstruct.lua

@@ -105,10 +105,11 @@ function reconstruct.image_y(model, x, offset, block_size)
    
    return output
 end
-function reconstruct.scale_y(model, scale, x, offset, block_size)
+function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
+   upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    local x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
-   x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
+   x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
    if x:size(2) * x:size(3) > 2048*2048 then
       collectgarbage()
    end
@@ -173,9 +174,10 @@ function reconstruct.image_rgb(model, x, offset, block_size)
 
    return output
 end
-function reconstruct.scale_rgb(model, scale, x, offset, block_size)
+function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
+   upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
-   x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
+   x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
    if x:size(2) * x:size(3) > 2048*2048 then
       collectgarbage()
    end
@@ -230,7 +232,7 @@ function reconstruct.image(model, x, block_size)
    end
    return x
 end
-function reconstruct.scale(model, scale, x, block_size)
+function reconstruct.scale(model, scale, x, block_size, upsampling_filter)
    local i2rgb = false
    if x:size(1) == 1 then
       local new_x = torch.Tensor(3, x:size(2), x:size(3))
@@ -242,10 +244,14 @@ function reconstruct.scale(model, scale, x, block_size)
    end
    if reconstruct.is_rgb(model) then
       x = reconstruct.scale_rgb(model, scale, x,
-				reconstruct.offset_size(model), block_size)
+				reconstruct.offset_size(model),
+				block_size,
+				upsampling_filter)
    else
       x = reconstruct.scale_y(model, scale, x,
-			      reconstruct.offset_size(model), block_size)
+			      reconstruct.offset_size(model),
+			      block_size,
+			      upsampling_filter)
    end
    if i2rgb then
       x = image.rgb2y(x)
@@ -297,16 +303,16 @@ function reconstruct.image_tta(model, x, block_size)
       return tta(reconstruct.image_y, model, x, block_size)
    end
 end
-function reconstruct.scale_tta(model, scale, x, block_size)
+function reconstruct.scale_tta(model, scale, x, block_size, upsampling_filter)
    if reconstruct.is_rgb(model) then
       local f = function (model, x, offset, block_size)
-	 return reconstruct.scale_rgb(model, scale, x, offset, block_size)
+	 return reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
       end
       return tta(f, model, x, block_size)
 		 
    else
       local f = function (model, x, offset, block_size)
-	 return reconstruct.scale_y(model, scale, x, offset, block_size)
+	 return reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
       end
       return tta(f, model, x, block_size)
    end

+ 1 - 0
lib/settings.lua

@@ -50,6 +50,7 @@ cmd:option("-save_history", 0, 'save all model (0|1)')
 cmd:option("-plot", 0, 'plot loss chart(0|1)')
 cmd:option("-downsampling_filters", "Box,Catrom", '(comma separated)downsampling filters for 2x scale training. (Point,Box,Triangle,Hermite,Hanning,Hamming,Blackman,Gaussian,Quadratic,Cubic,Catrom,Mitchell,Lanczos,Bessel,Sinc)')
 cmd:option("-gamma_correction", 0, 'Resizing with colorspace correction(sRGB:gamma 2.2) in scale training (0|1)')
+cmd:option("-upsampling_filter", "Box", 'upsampling filter for 2x scale training (dev)')
 
 local function to_bool(settings, name)
    if settings[name] == 1 then

+ 2 - 1
train.lua

@@ -15,7 +15,7 @@ local pairwise_transform = require 'pairwise_transform'
 local image_loader = require 'image_loader'
 
 local function save_test_scale(model, rgb, file)
-   local up = reconstruct.scale(model, settings.scale, rgb)
+   local up = reconstruct.scale(model, settings.scale, rgb, 128, settings.upsampling_filter)
    image.save(file, up)
 end
 local function save_test_jpeg(model, rgb, file)
@@ -113,6 +113,7 @@ local function transformer(x, is_validation, n, offset)
 				      n,
 				      {
 					 downsampling_filters = settings.downsampling_filters,
+					 upsampling_filter = settings.upsampling_filter,
 					 random_half_rate = settings.random_half_rate,
 					 random_color_noise_rate = random_color_noise_rate,
 					 random_overlay_rate = random_overlay_rate,

+ 5 - 4
waifu2x.lua

@@ -44,7 +44,7 @@ local function convert_image(opt)
 	 error("Load Error: " .. model_path)
       end
       x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model))
-      new_x = scale_f(model, opt.scale, x, opt.crop_size)
+      new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
       new_x = alpha_util.composite(new_x, alpha, model)
    elseif opt.m == "noise_scale" then
       local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
@@ -60,7 +60,7 @@ local function convert_image(opt)
       end
       x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
       x = image_f(noise_model, x, opt.crop_size)
-      new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
+      new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
       new_x = alpha_util.composite(new_x, alpha, scale_model)
    else
       error("undefined method:" .. opt.method)
@@ -122,12 +122,12 @@ local function convert_frames(opt)
 	    new_x = alpha_util.composite(new_x, alpha)
 	 elseif opt.m == "scale" then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
+	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 elseif opt.m == "noise_scale" then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
 	    x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
-	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
+	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, upsampling_filter)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 else
 	    error("undefined method:" .. opt.method)
@@ -169,6 +169,7 @@ local function waifu2x()
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")
    cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
+   cmd:option("-upsampling_filter", "Box", 'upsampling filter (for dev)')
    
    local opt = cmd:parse(arg)
    if opt.thread > 0 then

+ 10 - 4
web.lua

@@ -25,6 +25,8 @@ cmd:text("waifu2x-api")
 cmd:text("Options:")
 cmd:option("-port", 8812, 'listen port')
 cmd:option("-gpu", 1, 'Device ID')
+cmd:option("-upsampling_filter", "Box", 'Upsampling filter (for dev)')
+cmd:option("-crop_size", 128, 'patch size per process')
 cmd:option("-thread", -1, 'number of CPU threads')
 local opt = cmd:parse(arg)
 cutorch.setDevice(opt.gpu)
@@ -142,10 +144,12 @@ local function convert(x, alpha, options)
 	    x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(art_scale2_model))
 	 end
 	 if options.method == "scale" then
-	    x = reconstruct.scale(art_scale2_model, 2.0, x)
+	    x = reconstruct.scale(art_scale2_model, 2.0, x,
+				  opt.crop_size, opt.upsampling_filter)
 	    if alpha then
 	       if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
-		  alpha = reconstruct.scale(art_scale2_model, 2.0, alpha)
+		  alpha = reconstruct.scale(art_scale2_model, 2.0, alpha,
+					    opt.crop_size, opt.upsampling_filter)
 		  image_loader.save_png(alpha_cache_file, alpha)
 	       end
 	    end
@@ -165,10 +169,12 @@ local function convert(x, alpha, options)
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
 	 end
 	 if options.method == "scale" then
-	    x = reconstruct.scale(photo_scale2_model, 2.0, x)
+	    x = reconstruct.scale(photo_scale2_model, 2.0, x,
+				  opt.crop_size, opt.upsampling_filter)
 	    if alpha then
 	       if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
-		  alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha)
+		  alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha,
+					    opt.crop_size, opt.upsampling_filter)
 		  image_loader.save_png(alpha_cache_file, alpha)
 	       end
 	    end