Pārlūkot izejas kodu

Support for binary model file

nagadomi 6 gadi atpakaļ
vecāks
revīzija
20abbbc68e
2 mainītis faili ar 17 papildinājumiem un 15 dzēšanām
  1. 3 2
      lib/w2nn.lua
  2. 14 13
      waifu2x.lua

+ 3 - 2
lib/w2nn.lua

@@ -53,8 +53,9 @@ else
    end
    pcall(load_cudnn)
 
-   function w2nn.load_model(model_path, force_cudnn)
-      local model = torch.load(model_path, "ascii")
+   function w2nn.load_model(model_path, force_cudnn, mode)
+      mode = mode or "ascii"
+      local model = torch.load(model_path, mode)
       if force_cudnn then
 	 model = cudnn.convert(model, cudnn)
       end

+ 14 - 13
waifu2x.lua

@@ -62,7 +62,7 @@ local function convert_image(opt)
    opt.o = format_output(opt, opt.i)
    if opt.m == "noise" then
       local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
-      local model = w2nn.load_model(model_path, opt.force_cudnn)
+      local model = w2nn.load_model(model_path, opt.force_cudnn, opt.load_mode)
       if not model then
 	 error("Load Error: " .. model_path)
       end
@@ -74,7 +74,7 @@ local function convert_image(opt)
       end
    elseif opt.m == "scale" then
       local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
-      local model = w2nn.load_model(model_path, opt.force_cudnn)
+      local model = w2nn.load_model(model_path, opt.force_cudnn, opt.load_mode)
       if not model then
 	 error("Load Error: " .. model_path)
       end
@@ -90,7 +90,7 @@ local function convert_image(opt)
       if path.exists(model_path) then
 	 local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
 	 local t, scale_model = pcall(w2nn.load_model, scale_model_path, opt.force_cudnn)
-	 local model = w2nn.load_model(model_path, opt.force_cudnn)
+	 local model = w2nn.load_model(model_path, opt.force_cudnn, opt.load_mode)
 	 if not t then
 	    scale_model = model
 	 end
@@ -103,9 +103,9 @@ local function convert_image(opt)
 	 end
       else
 	 local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
-	 local noise_model = w2nn.load_model(noise_model_path, opt.force_cudnn)
+	 local noise_model = w2nn.load_model(noise_model_path, opt.force_cudnn, opt.load_mode)
 	 local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
-	 local scale_model = w2nn.load_model(scale_model_path, opt.force_cudnn)
+	 local scale_model = w2nn.load_model(scale_model_path, opt.force_cudnn, opt.load_mode)
 	 local t = sys.clock()
 	 x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
 	 x = image_f(noise_model, x, opt.crop_size, opt.batch_size)
@@ -117,7 +117,7 @@ local function convert_image(opt)
       end
    elseif opt.m == "user" then
       local model_path = opt.model_path
-      local model = w2nn.load_model(model_path, opt.force_cudnn)
+      local model = w2nn.load_model(model_path, opt.force_cudnn, opt.load_mode)
       if not model then
 	 error("Load Error: " .. model_path)
       end
@@ -159,27 +159,27 @@ local function convert_frames(opt)
    end
    if opt.m == "scale" then
       model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
-      scale_model = w2nn.load_model(model_path, opt.force_cudnn)
+      scale_model = w2nn.load_model(model_path, opt.force_cudnn, opt.load_mode)
    elseif opt.m == "noise" then
       model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
-      noise_model[opt.noise_level] = w2nn.load_model(model_path, opt.force_cudnn)
+      noise_model[opt.noise_level] = w2nn.load_model(model_path, opt.force_cudnn, opt.load_mode)
    elseif opt.m == "noise_scale" then
       local model_path = path.join(opt.model_dir, ("noise%d_scale%.1fx_model.t7"):format(opt.noise_level, opt.scale))
       if path.exists(model_path) then
-	 noise_scale_model[opt.noise_level] = w2nn.load_model(model_path, opt.force_cudnn)
+	 noise_scale_model[opt.noise_level] = w2nn.load_model(model_path, opt.force_cudnn, opt.load_mode)
 	 model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
-	 t, scale_model = pcall(w2nn.load_model, model_path, opt.force_cudnn)
+	 t, scale_model = pcall(w2nn.load_model, model_path, opt.force_cudnn, opt.load_mode)
 	 if not t then
 	    scale_model = noise_scale_model[opt.noise_level]
 	 end
       else
 	 model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
-	 scale_model = w2nn.load_model(model_path, opt.force_cudnn)
+	 scale_model = w2nn.load_model(model_path, opt.force_cudnn, opt.load_mode)
 	 model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
-	 noise_model[opt.noise_level] = w2nn.load_model(model_path, opt.force_cudnn)
+	 noise_model[opt.noise_level] = w2nn.load_model(model_path, opt.force_cudnn, opt.load_mode)
       end
    elseif opt.m == "user" then
-      user_model = w2nn.load_model(opt.model_path, opt.force_cudnn)
+      user_model = w2nn.load_model(opt.model_path, opt.force_cudnn, opt.load_mode)
    end
    local fp = io.open(opt.l)
    if not fp then
@@ -268,6 +268,7 @@ local function waifu2x()
    cmd:option("-force_cudnn", 0, 'use cuDNN backend (0|1)')
    cmd:option("-q", 0, 'quiet (0|1)')
    cmd:option("-gpu", 1, 'Device ID')
+   cmd:option("-load_mode", "ascii", "ascii/binary")
 
    local opt = cmd:parse(arg)
    if opt.method:len() > 0 then