Sfoglia il codice sorgente

Add support for method=noise_scale to tools/benchmark.lua

nagadomi 9 anni fa
parent
commit
83188c5ab7
1 ha cambiato i file con 86 aggiunte e 5 eliminazioni
  1. 86 5
      tools/benchmark.lua

+ 86 - 5
tools/benchmark.lua

@@ -17,7 +17,7 @@ cmd:text("Options:")
 cmd:option("-dir", "./data/test", 'test image directory')
 cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
 cmd:option("-model2_dir", "", 'model2 directory (optional)')
-cmd:option("-method", "scale", '(scale|noise)')
+cmd:option("-method", "scale", '(scale|noise|noise_scale)')
 cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))")
 cmd:option("-resize_blur", 1.0, 'blur parameter for resize')
 cmd:option("-color", "y", '(rgb|y)')
@@ -129,6 +129,22 @@ local function transform_scale(x, opt)
 		      opt.filter, opt.resize_blur)
 end
 
+local function transform_scale_jpeg(x, opt)
+   x = iproc.scale(x,
+		   x:size(3) * 0.5,
+		   x:size(2) * 0.5,
+		   opt.filter, opt.resize_blur)
+   for i = 1, opt.jpeg_times do
+      jpeg = gm.Image(x, "RGB", "DHW")
+      jpeg:format("jpeg")
+      jpeg:samplingFactors({1.0, 1.0, 1.0})
+      blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
+      jpeg:fromBlob(blob, len)
+      x = jpeg:toTensor("byte", "RGB", "DHW")
+   end
+   return iproc.byte2float(x)
+end
+
 local function benchmark(opt, x, input_func, model1, model2)
    local model1_mse = 0
    local model2_mse = 0
@@ -157,15 +173,45 @@ local function benchmark(opt, x, input_func, model1, model2)
 
       input = input_func(ground_truth, opt)
       t = sys.clock()
-      if input:size(3) == ground_truth:size(3) then
+      if opt.method == "scale" then
+	 model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
+	 if model2 then
+	    model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
+	 end
+	 baseline_output = baseline_scale(input, opt.baseline_filter)
+      elseif opt.method == "noise" then
 	 model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
 	 if model2 then
 	    model2_output = image_f(model2, input, opt.crop_size, opt.batch_size)
 	 end
-      else
-	 model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
+	 baseline_output = input
+      elseif opt.method == "noise_scale" then
+	 if model1.noise_scale_model then
+	    model1_output = scale_f(model1.noise_scale_model, 2.0,
+				    input, opt.crop_size, opt.batch_size)
+	 else
+	    if model1.noise_model then
+	       model1_output = image_f(model1.noise_model, input, opt.crop_size, opt.batch_size)
+	    else
+	       model1_output = input
+	    end
+	    model1_output = scale_f(model1.scale_model, 2.0, model1_output,
+				    opt.crop_size, opt.batch_size)
+	 end
 	 if model2 then
-	    model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
+	    if model2.noise_scale_model then
+	       model2_output = scale_f(model2.noise_scale_model, 2.0,
+				       input, opt.crop_size, opt.batch_size)
+	    else
+	       if model2.noise_model then
+		  model2_output = image_f(model2.noise_model, input,
+					  opt.crop_size, opt.batch_size)
+	       else
+		  model2_output = input
+	       end
+	       model2_output = scale_f(model2.scale_model, 2.0, model2_output,
+				       opt.crop_size, opt.batch_size)
+	    end
 	 end
 	 baseline_output = baseline_scale(input, opt.baseline_filter)
       end
@@ -271,9 +317,36 @@ end
 function load_model(filename)
    return torch.load(filename, "ascii")
 end
+function load_noise_scale_model(model_dir, noise_level)
+   local f = path.join(model_dir, string.format("noise%d_scale2.0x_model.t7", opt.noise_level))
+   local s1, noise_scale = pcall(load_model, f)
+   local model = {}
+   if not s1 then
+      f = path.join(model_dir, string.format("noise%d_model.t7", opt.noise_level))
+      local noise
+      s1, noise = pcall(load_model, f)
+      if not s1 then
+	 model.noise_model = nil
+	 print(model_dir .. "'s noise model is not found. benchmark will use only scale model.")
+      else
+	 model.noise_model = noise
+      end
+      f = path.join(model_dir, "scale2.0x_model.t7")
+      local scale
+      s1, scale = pcall(load_model, f)
+      if not s1 then
+	 return nil
+      end
+      model.scale_model = scale
+   else
+      model.noise_scale_model = noise_scale
+   end
+   return model
+end
 if opt.show_progress then
    print(opt)
 end
+
 if opt.method == "scale" then
    local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
    local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
@@ -300,4 +373,12 @@ elseif opt.method == "noise" then
    end
    local test_x = load_data(opt.dir)
    benchmark(opt, test_x, transform_jpeg, model1, model2)
+elseif opt.method == "noise_scale" then
+   local model2 = nil
+   local model1 = load_noise_scale_model(opt.model1_dir, opt.noise_level)
+   if opt.model2_dir:len() > 0 then
+      model2 = load_noise_scale_model(opt.model2_dir, opt.noise_level)
+   end
+   local test_x = load_data(opt.dir)
+   benchmark(opt, test_x, transform_scale_jpeg, model1, model2)
 end