Browse Source

Add support for new noise_scale method

nagadomi 9 years ago
parent
commit
37bc7a5eea
1 changed files with 63 additions and 27 deletions
  1. 63 27
      waifu2x.lua

+ 63 - 27
waifu2x.lua

@@ -73,23 +73,41 @@ local function convert_image(opt)
       new_x = alpha_util.composite(new_x, alpha, model)
       new_x = alpha_util.composite(new_x, alpha, model)
       print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
       print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
    elseif opt.m == "noise_scale" then
    elseif opt.m == "noise_scale" then
-      local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
-      local noise_model = torch.load(noise_model_path, "ascii")
-      local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
-      local scale_model = torch.load(scale_model_path, "ascii")
-      
-      if not noise_model then
-	 error("Load Error: " .. noise_model_path)
-      end
-      if not scale_model then
-	 error("Load Error: " .. scale_model_path)
+      local model_path = path.join(opt.model_dir, ("noise%d_scale%.1fx_model.t7"):format(opt.noise_level, opt.scale))
+      if path.exists(model_path) then
+	 local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+	 local scale_model = torch.load(scale_model_path, "ascii")
+	 local model = torch.load(model_path, "ascii")
+	 if not model then
+	    error("Load Error: " .. model_path)
+	 end
+	 if not scale_model then
+	    error("Load Error: " .. model_path)
+	 end
+	 local t = sys.clock()
+	 x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
+	 new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+	 new_x = alpha_util.composite(new_x, alpha, scale_model)
+	 print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
+      else
+	 local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
+	 local noise_model = torch.load(noise_model_path, "ascii")
+	 local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+	 local scale_model = torch.load(scale_model_path, "ascii")
+	 
+	 if not noise_model then
+	    error("Load Error: " .. noise_model_path)
+	 end
+	 if not scale_model then
+	    error("Load Error: " .. scale_model_path)
+	 end
+	 local t = sys.clock()
+	 x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
+	 x = image_f(noise_model, x, opt.crop_size)
+	 new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+	 new_x = alpha_util.composite(new_x, alpha, scale_model)
+	 print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
       end
       end
-      local t = sys.clock()
-      x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-      x = image_f(noise_model, x, opt.crop_size)
-      new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
-      new_x = alpha_util.composite(new_x, alpha, scale_model)
-      print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
    else
    else
       error("undefined method:" .. opt.method)
       error("undefined method:" .. opt.method)
    end
    end
@@ -97,6 +115,7 @@ local function convert_image(opt)
 end
 end
 local function convert_frames(opt)
 local function convert_frames(opt)
    local model_path, scale_model
    local model_path, scale_model
+   local noise_scale_model = {}
    local noise_model = {}
    local noise_model = {}
    local scale_f, image_f
    local scale_f, image_f
    if opt.tta == 1 then
    if opt.tta == 1 then
@@ -119,15 +138,28 @@ local function convert_frames(opt)
 	 error("Load Error: " .. model_path)
 	 error("Load Error: " .. model_path)
       end
       end
    elseif opt.m == "noise_scale" then
    elseif opt.m == "noise_scale" then
-      model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
-      scale_model = torch.load(model_path, "ascii")
-      if not scale_model then
-	 error("Load Error: " .. model_path)
-      end
-      model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
-      noise_model[opt.noise_level] = torch.load(model_path, "ascii")
-      if not noise_model[opt.noise_level] then
-	 error("Load Error: " .. model_path)
+      local model_path = path.join(opt.model_dir, ("noise%d_scale%.1fx_model.t7"):format(opt.noise_level, opt.scale))
+      if path.exists(model_path) then
+	 noise_scale_model[opt.noise_level] = torch.load(model_path, "ascii")
+	 if not noise_scale_model[opt.noise_level] then
+	    error("Load Error: " .. model_path)
+	 end
+	 model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+	 scale_model = torch.load(model_path, "ascii")
+	 if not scale_model then
+	    error("Load Error: " .. model_path)
+	 end
+      else
+	 model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+	 scale_model = torch.load(model_path, "ascii")
+	 if not scale_model then
+	    error("Load Error: " .. model_path)
+	 end
+	 model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
+	 noise_model[opt.noise_level] = torch.load(model_path, "ascii")
+	 if not noise_model[opt.noise_level] then
+	    error("Load Error: " .. model_path)
+	 end
       end
       end
    end
    end
    local fp = io.open(opt.l)
    local fp = io.open(opt.l)
@@ -155,8 +187,12 @@ local function convert_frames(opt)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 elseif opt.m == "noise_scale" then
 	 elseif opt.m == "noise_scale" then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-	    x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
-	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, upsampling_filter)
+	    if noise_scale_model[opt.noise_level] then
+	       new_x = scale_f(noise_scale_model[opt.noise_level], opt.scale, x, opt.crop_size, upsampling_filter)
+	    else
+	       x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
+	       new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, upsampling_filter)
+	    end
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 else
 	 else
 	    error("undefined method:" .. opt.method)
 	    error("undefined method:" .. opt.method)