瀏覽代碼

Improve output file format

supported format variable:
  %s: basename of the source filename
  %d: sequence number

example:
   output/2x/%s.png
   output/%d.png
   output/%06d_%s.png
   output/%s_%d.png
nagadomi 9 年之前
父節點
當前提交
1464b0db3e
共有 1 個文件被更改,包括 30 次插入16 次删除
  1. 30 16
      waifu2x.lua

+ 30 - 16
waifu2x.lua

@@ -10,6 +10,33 @@ local alpha_util = require 'alpha_util'
 
 torch.setdefaulttensortype('torch.FloatTensor')
 
+local function format_output(opt, src, no)
+   no = no or 1
+   local name = path.basename(src)
+   local e = path.extension(name)
+   local basename = name:sub(0, name:len() - e:len())
+   
+   if opt.o == "(auto)" then
+      return path.join(path.dirname(src), string.format("%s_%s.png", basename, opt.m))
+   else
+      local basename_pos = opt.o:find("%%s")
+      local no_pos = opt.o:find("%%%d*d")
+      if basename_pos ~= nil and no_pos ~= nil then
+	 if basename_pos < no_pos then
+	    return string.format(opt.o, basename, no)
+	 else
+	    return string.format(opt.o, no, basename)
+	 end
+      elseif basename_pos ~= nil then
+	 return string.format(opt.o, basename)
+      elseif no_pos ~= nil then
+	 return string.format(opt.o, no)
+      else
+	 return opt.o
+      end
+   end
+end
+
 local function convert_image(opt)
    local x, meta = image_loader.load_float(opt.i)
    local alpha = meta.alpha
@@ -24,12 +51,7 @@ local function convert_image(opt)
       scale_f = reconstruct.scale
       image_f = reconstruct.image
    end
-   if opt.o == "(auto)" then
-      local name = path.basename(opt.i)
-      local e = path.extension(name)
-      local base = name:sub(0, name:len() - e:len())
-      opt.o = path.join(path.dirname(opt.i), string.format("%s_%s.png", base, opt.m))
-   end
+   opt.o = format_output(opt, opt.i)
    if opt.m == "noise" then
       local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
       local model = torch.load(model_path, "ascii")
@@ -115,7 +137,8 @@ local function convert_frames(opt)
    end
    fp:close()
    for i = 1, #lines do
-      if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
+      local output = format_output(opt, lines[i], i)
+      if opt.resume == 0 or path.exists(output) == false then
 	 local x, meta = image_loader.load_float(lines[i])
 	 local alpha = meta.alpha
 	 local new_x = nil
@@ -134,15 +157,6 @@ local function convert_frames(opt)
 	 else
 	    error("undefined method:" .. opt.method)
 	 end
-	 local output = nil
-	 if opt.o == "(auto)" then
-	    local name = path.basename(lines[i])
-	    local e = path.extension(name)
-	    local base = name:sub(0, name:len() - e:len())
-	    output = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
-	 else
-	    output = string.format(opt.o, i)
-	 end
 	 image_loader.save_png(output, new_x, 
 			       tablex.update({depth = opt.depth, inplace = true}, meta))
 	 xlua.progress(i, #lines)