Procházet zdrojové kódy

Add support for noise level 3

nagadomi před 9 roky
rodič
revize
8a799e2d56
3 změnil soubory, kde provedl 18 přidání a 38 odebrání
  1. 2 1
      lib/pairwise_transform.lua
  2. 1 1
      lib/settings.lua
  3. 15 36
      waifu2x.lua

+ 2 - 1
lib/pairwise_transform.lua

@@ -168,7 +168,8 @@ function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
       if level == 1 then
 	 return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
 					 size, offset, n, options)
-      elseif level == 2 then
+      elseif level == 2 or level == 3 then
+	 -- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
 	 local r = torch.uniform()
 	 if r > 0.6 then
 	    return pairwise_transform.jpeg_(src, {torch.random(27, 70)},

+ 1 - 1
lib/settings.lua

@@ -24,7 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
 cmd:option("-test", "images/miku_small.png", 'path to test image')
 cmd:option("-model_dir", "./models", 'model directory')
 cmd:option("-method", "scale", 'method to training (noise|scale)')
-cmd:option("-noise_level", 1, '(1|2)')
+cmd:option("-noise_level", 1, '(1|2|3)')
 cmd:option("-style", "art", '(art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')

+ 15 - 36
waifu2x.lua

@@ -69,7 +69,8 @@ local function convert_image(opt)
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 local function convert_frames(opt)
-   local model_path, noise1_model, noise2_model, scale_model
+   local model_path, scale_model
+   local noise_model = {}
    local scale_f, image_f
    if opt.tta == 1 then
       scale_f = reconstruct.scale_tta
@@ -84,16 +85,10 @@ local function convert_frames(opt)
       if not scale_model then
 	 error("Load Error: " .. model_path)
       end
-   elseif opt.m == "noise" and opt.noise_level == 1 then
-      model_path = path.join(opt.model_dir, "noise1_model.t7")
-      noise1_model = torch.load(model_path, "ascii")
-      if not noise1_model then
-	 error("Load Error: " .. model_path)
-      end
-   elseif opt.m == "noise" and opt.noise_level == 2 then
-      model_path = path.join(opt.model_dir, "noise2_model.t7")
-      noise2_model = torch.load(model_path, "ascii")
-      if not noise2_model then
+   elseif opt.m == "noise" then
+      model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
+      noise_model[opt.noise_level] = torch.load(model_path, "ascii")
+      if not noise_model[opt.noise_level] then
 	 error("Load Error: " .. model_path)
       end
    elseif opt.m == "noise_scale" then
@@ -102,18 +97,10 @@ local function convert_frames(opt)
       if not scale_model then
 	 error("Load Error: " .. model_path)
       end
-      if opt.noise_level == 1 then
-	 model_path = path.join(opt.model_dir, "noise1_model.t7")
-	 noise1_model = torch.load(model_path, "ascii")
-	 if not noise1_model then
-	    error("Load Error: " .. model_path)
-	 end
-      elseif opt.noise_level == 2 then
-	 model_path = path.join(opt.model_dir, "noise2_model.t7")
-	 noise2_model = torch.load(model_path, "ascii")
-	 if not noise2_model then
-	    error("Load Error: " .. model_path)
-	 end
+      model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
+      noise_model[opt.noise_level] = torch.load(model_path, "ascii")
+      if not noise_model[opt.noise_level] then
+	 error("Load Error: " .. model_path)
       end
    end
    local fp = io.open(opt.l)
@@ -130,24 +117,16 @@ local function convert_frames(opt)
       if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
 	 local x, alpha = image_loader.load_float(lines[i])
 	 local new_x = nil
-	 if opt.m == "noise" and opt.noise_level == 1 then
-	    new_x = image_f(noise1_model, x, opt.crop_size)
-	    new_x = alpha_util.composite(new_x, alpha)
-	 elseif opt.m == "noise" and opt.noise_level == 2 then
-	    new_x = image_f(noise2_model, x, opt.crop_size)
+	 if opt.m == "noise" then
+	    new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
 	    new_x = alpha_util.composite(new_x, alpha)
 	 elseif opt.m == "scale" then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
 	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
-	 elseif opt.m == "noise_scale" and opt.noise_level == 1 then
-	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-	    x = image_f(noise1_model, x, opt.crop_size)
-	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
-	    new_x = alpha_util.composite(new_x, alpha, scale_model)
-	 elseif opt.m == "noise_scale" and opt.noise_level == 2 then
+	 elseif opt.m == "noise_scale" then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-	    x = image_f(noise2_model, x, opt.crop_size)
+	    x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
 	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 else
@@ -185,7 +164,7 @@ local function waifu2x()
    cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
    cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
-   cmd:option("-noise_level", 1, '(1|2)')
+   cmd:option("-noise_level", 1, '(1|2|3)')
    cmd:option("-crop_size", 128, 'patch size per process')
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")