Forráskód Böngészése

Fix model loading error when using ukbench model; Add error handling

nagadomi 9 éve
szülő
commit
1f91548c6e
1 módosított fájl, 45 hozzáadás és 19 törlés
  1. 45 19
      waifu2x.lua

+ 45 - 19
waifu2x.lua

@@ -17,23 +17,34 @@ local function convert_image(opt)
       local name = path.basename(opt.i)
       local e = path.extension(name)
       local base = name:sub(0, name:len() - e:len())
-      opt.o = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
+      opt.o = path.join(path.dirname(opt.i), string.format("%s_%s.png", base, opt.m))
    end
    if opt.m == "noise" then
-      local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
-      --local srcnn = require 'lib/srcnn'
-      --local model = srcnn.waifu2x("rgb"):cuda()
-      model:evaluate()
+      local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
+      local model = torch.load(model_path, "ascii")
+      if not model then
+	 error("Load Error: " .. model_path)
+      end
       new_x = reconstruct.image(model, x, opt.crop_size)
    elseif opt.m == "scale" then
-      local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
-      model:evaluate()
+      local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+      local model = torch.load(model_path, "ascii")
+      if not model then
+	 error("Load Error: " .. model_path)
+      end
       new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
    elseif opt.m == "noise_scale" then
-      local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
-      local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
-      noise_model:evaluate()
-      scale_model:evaluate()
+      local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
+      local noise_model = torch.load(noise_model_path, "ascii")
+      local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+      local scale_model = torch.load(scale_model_path, "ascii")
+      
+      if not noise_model then
+	 error("Load Error: " .. noise_model_path)
+      end
+      if not scale_model then
+	 error("Load Error: " .. scale_model_path)
+      end
       x = reconstruct.image(noise_model, x)
       new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
    else
@@ -43,15 +54,30 @@ local function convert_image(opt)
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 local function convert_frames(opt)
-   local noise1_model = torch.load(path.join(opt.model_dir, "noise1_model.t7"), "ascii")
-   local noise2_model = torch.load(path.join(opt.model_dir, "noise2_model.t7"), "ascii")
-   local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
-
-   noise1_model:evaluate()
-   noise2_model:evaluate()
-   scale_model:evaluate()
-   
+   local noise1_model, noise2_model, scale_model
+   if opt.m == "scale" then
+      local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
+      scale_model = torch.load(model_path, "ascii")
+      if not scale_model then
+	 error("Load Error: " .. model_path)
+      end
+   elseif opt.m == "noise" and opt.noise_level == 1 then
+      local model_path = path.join(opt.model_dir, "noise1_model.t7")
+      noise1_model = torch.load(model_path, "ascii")
+      if not noise1_model then
+	 error("Load Error: " .. model_path)
+      end
+   elseif opt.m == "noise" and opt.noise_level == 2 then
+      local model_path = path.join(opt.model_dir, "noise2_model.t7")
+      noise2_model = torch.load(model_path, "ascii")
+      if not noise2_model then
+	 error("Load Error: " .. model_path)
+      end
+   end
    local fp = io.open(opt.l)
+   if not fp then
+      error("Open Error: " .. opt.l)
+   end
    local count = 0
    local lines = {}
    for line in fp:lines() do