Преглед изворни кода

Add -batch_size option to waifu2x.lua/web.lua

nagadomi пре 9 година
родитељ
комит
afac4b52ab
3 измењених фајлова са 67 додато и 43 уклоњено
  1. 37 14
      lib/reconstruct.lua
  2. 19 19
      waifu2x.lua
  3. 11 10
      web.lua

+ 37 - 14
lib/reconstruct.lua

@@ -2,7 +2,8 @@ require 'image'
 local iproc = require 'iproc'
 local srcnn = require 'srcnn'
 
-local function reconstruct_nn(model, x, inner_scale, offset, block_size)
+local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_size)
+   batch_size = batch_size or 1
    if x:dim() == 2 then
       x = x:reshape(1, x:size(1), x:size(2))
    end
@@ -12,24 +13,46 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size)
    local output_block_size = block_size
    local output_size = output_block_size - offset * 2
    local output_size_in_input = input_block_size - math.ceil(offset / inner_scale) * 2
-   local input = torch.CudaTensor(1, ch, input_block_size, input_block_size)
+   local input_indexes = {}
+   local output_indexes = {}
    for i = 1, x:size(2), output_size_in_input do
       for j = 1, x:size(3), output_size_in_input do
 	 if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
 	    local index = {{},
 			   {i, i + input_block_size - 1},
 			   {j, j + input_block_size - 1}}
-	    input:copy(x[index])
-	    local output = model:forward(input)
-	    output = output:view(ch, output_size, output_size)
 	    local ii = (i - 1) * inner_scale + 1
 	    local jj = (j - 1) * inner_scale + 1
 	    local output_index = {{}, { ii , ii + output_size - 1 },
 	       { jj, jj + output_size - 1}}
-	    new_x[output_index]:copy(output)
+	    table.insert(input_indexes, index)
+	    table.insert(output_indexes, output_index)
 	 end
       end
    end
+   local input = torch.Tensor(batch_size, ch, input_block_size, input_block_size)
+   local input_cuda = torch.CudaTensor(batch_size, ch, input_block_size, input_block_size)
+   for i = 1, #input_indexes, batch_size do
+      local c = 0
+      local output
+      for j = 0, batch_size - 1 do
+	 if i + j > #input_indexes then
+	    break
+	 end
+	 input[j+1]:copy(x[input_indexes[i + j]])
+	 c = c + 1
+      end
+      input_cuda:copy(input)
+      if c == batch_size then
+	 output = model:forward(input_cuda)
+      else
+	 output = model:forward(input_cuda:narrow(1, 1, c))
+      end
+      --output = output:view(batch_size, ch, output_size, output_size)
+      for j = 0, c - 1 do
+	 new_x[output_indexes[i + j]]:copy(output[j+1])
+      end
+   end
    return new_x
 end
 local reconstruct = {}
@@ -72,11 +95,11 @@ local function padding_params(x, model, block_size)
    p.pad_w2 = (w - input_offset) - p.x_w
    return p
 end
-function reconstruct.image_y(model, x, offset, block_size)
+function reconstruct.image_y(model, x, offset, block_size, batch_size)
    block_size = block_size or 128
    local p = padding_params(x, model, block_size)
    x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
-   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
+   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
    x = iproc.crop(x, p.pad_w1, p.pad_w2, p.pad_w1 + p.x_w, p.pad_w2 + p.x_h)
    y = iproc.crop(y, 0, 0, p.x_w, p.x_h)
    y[torch.lt(y, 0)] = 0
@@ -91,7 +114,7 @@ function reconstruct.image_y(model, x, offset, block_size)
    
    return output
 end
-function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
+function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    local x_lanczos
@@ -107,7 +130,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_fil
    end
    x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
    x_lanczos = image.rgb2yuv(x_lanczos)
-   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
+   local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
    y = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
@@ -122,14 +145,14 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_fil
    
    return output
 end
-function reconstruct.image_rgb(model, x, offset, block_size)
+function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
    block_size = block_size or 128
    local p = padding_params(x, model, block_size)
    x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
    if p.x_w * p.x_h > 2048*2048 then
       collectgarbage()
    end
-   local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
+   local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
    local output = iproc.crop(y, 0, 0, p.x_w, p.x_h)
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
@@ -139,7 +162,7 @@ function reconstruct.image_rgb(model, x, offset, block_size)
 
    return output
 end
-function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
+function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size, upsampling_filter)
    upsampling_filter = upsampling_filter or "Box"
    block_size = block_size or 128
    if not reconstruct.has_resize(model) then
@@ -151,7 +174,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_f
       collectgarbage()
    end
    local y
-   y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
+   y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
    local output = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1

+ 19 - 19
waifu2x.lua

@@ -44,13 +44,13 @@ local function convert_image(opt)
    local scale_f, image_f
 
    if opt.tta == 1 then
-      scale_f = function(model, scale, x, block_size, upsampling_filter)
+      scale_f = function(model, scale, x, block_size, batch_size, batupsampling_filter)
 	 return reconstruct.scale_tta(model, opt.tta_level,
-				      scale, x, block_size, upsampling_filter)
+				      scale, x, block_size, batch_size, upsampling_filter)
       end
-      image_f = function(model, x, block_size)
+      image_f = function(model, x, block_size, batch_size)
 	 return reconstruct.image_tta(model, opt.tta_level,
-				      x, block_size)
+				      x, block_size, batch_size)
       end
    else
       scale_f = reconstruct.scale
@@ -64,7 +64,7 @@ local function convert_image(opt)
 	 error("Load Error: " .. model_path)
       end
       local t = sys.clock()
-      new_x = image_f(model, x, opt.crop_size)
+      new_x = image_f(model, x, opt.crop_size, opt.batch_size)
       new_x = alpha_util.composite(new_x, alpha)
       print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
    elseif opt.m == "scale" then
@@ -75,7 +75,7 @@ local function convert_image(opt)
       end
       local t = sys.clock()
       x = alpha_util.make_border(x, alpha, reconstruct.offset_size(model))
-      new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+      new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.batch_size, opt.batch_size, opt.upsampling_filter)
       new_x = alpha_util.composite(new_x, alpha, model)
       print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
    elseif opt.m == "noise_scale" then
@@ -92,7 +92,7 @@ local function convert_image(opt)
 	 end
 	 local t = sys.clock()
 	 x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-	 new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+	 new_x = scale_f(model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
 	 new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
       else
@@ -109,8 +109,8 @@ local function convert_image(opt)
 	 end
 	 local t = sys.clock()
 	 x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-	 x = image_f(noise_model, x, opt.crop_size)
-	 new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+	 x = image_f(noise_model, x, opt.crop_size, opt.batch_size)
+	 new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
 	 new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
       end
@@ -125,13 +125,13 @@ local function convert_frames(opt)
    local noise_model = {}
    local scale_f, image_f
    if opt.tta == 1 then
-      scale_f = function(model, scale, x, block_size, upsampling_filter)
+      scale_f = function(model, scale, x, block_size, batch_size, upsampling_filter)
 	 return reconstruct.scale_tta(model, opt.tta_level,
-				      scale, x, block_size, upsampling_filter)
+				      scale, x, block_size, batch_size, upsampling_filter)
       end
-      image_f = function(model, x, block_size)
+      image_f = function(model, x, block_size, batch_size)
 	 return reconstruct.image_tta(model, opt.tta_level,
-				      x, block_size)
+				      x, block_size, batch_size)
       end
    else
       scale_f = reconstruct.scale
@@ -191,19 +191,19 @@ local function convert_frames(opt)
 	 local alpha = meta.alpha
 	 local new_x = nil
 	 if opt.m == "noise" then
-	    new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
+	    new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size, opt.batch_size)
 	    new_x = alpha_util.composite(new_x, alpha)
 	 elseif opt.m == "scale" then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
-	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.upsampling_filter)
+	    new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size, opt.upsampling_filter)
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 elseif opt.m == "noise_scale" then
 	    x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
 	    if noise_scale_model[opt.noise_level] then
-	       new_x = scale_f(noise_scale_model[opt.noise_level], opt.scale, x, opt.crop_size, upsampling_filter)
+	       new_x = scale_f(noise_scale_model[opt.noise_level], opt.scale, x, opt.crop_size, opt.batch_size, upsampling_filter)
 	    else
-	       x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
-	       new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, upsampling_filter)
+	       x = image_f(noise_model[opt.noise_level], x, opt.crop_size, opt.batch_size)
+	       new_x = scale_f(scale_model, opt.scale, x, opt.crop_size, opt.batch_size, upsampling_filter)
 	    end
 	    new_x = alpha_util.composite(new_x, alpha, scale_model)
 	 else
@@ -220,7 +220,6 @@ local function convert_frames(opt)
       end
    end
 end
-
 local function waifu2x()
    local cmd = torch.CmdLine()
    cmd:text()
@@ -235,6 +234,7 @@ local function waifu2x()
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2|3)')
    cmd:option("-crop_size", 128, 'patch size per process')
+   cmd:option("-batch_size", 1, 'batch_size')
    cmd:option("-resume", 0, "skip existing files (0|1)")
    cmd:option("-thread", -1, "number of CPU threads")
    cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')

+ 11 - 10
web.lua

@@ -27,6 +27,7 @@ cmd:option("-port", 8812, 'listen port')
 cmd:option("-gpu", 1, 'Device ID')
 cmd:option("-upsampling_filter", "Box", 'Upsampling filter (for dev)')
 cmd:option("-crop_size", 128, 'patch size per process')
+cmd:option("-batch_size", 1, 'batch size')
 cmd:option("-thread", -1, 'number of CPU threads')
 local opt = cmd:parse(arg)
 cutorch.setDevice(opt.gpu)
@@ -148,23 +149,23 @@ local function convert(x, meta, options)
 	 end
 	 if options.method == "scale" then
 	    x = reconstruct.scale(art_scale2_model, 2.0, x,
-				  opt.crop_size, opt.upsampling_filter)
+				  opt.crop_size, opt.batch_size, opt.upsampling_filter)
 	    if alpha then
 	       if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
 		  alpha = reconstruct.scale(art_scale2_model, 2.0, alpha,
-					    opt.crop_size, opt.upsampling_filter)
+					    opt.crop_size, opt.batch_size, opt.upsampling_filter)
 		  image_loader.save_png(alpha_cache_file, alpha)
 	       end
 	    end
 	    cleanup_model(art_scale2_model)
 	 elseif options.method == "noise1" then
-	    x = reconstruct.image(art_noise1_model, x)
+	    x = reconstruct.image(art_noise1_model, x, opt.crop_size, opt.batch_size)
 	    cleanup_model(art_noise1_model)
 	 elseif options.method == "noise2" then
-	    x = reconstruct.image(art_noise2_model, x)
+	    x = reconstruct.image(art_noise2_model, x, opt.crop_size, opt.batch_size)
 	    cleanup_model(art_noise2_model)
 	 elseif options.method == "noise3" then
-	    x = reconstruct.image(art_noise3_model, x)
+	    x = reconstruct.image(art_noise3_model, x, opt.crop_size, opt.batch_size)
 	    cleanup_model(art_noise3_model)
 	 end
       else -- photo
@@ -173,23 +174,23 @@ local function convert(x, meta, options)
 	 end
 	 if options.method == "scale" then
 	    x = reconstruct.scale(photo_scale2_model, 2.0, x,
-				  opt.crop_size, opt.upsampling_filter)
+				  opt.crop_size, opt.batch_size, opt.upsampling_filter)
 	    if alpha then
 	       if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then
 		  alpha = reconstruct.scale(photo_scale2_model, 2.0, alpha,
-					    opt.crop_size, opt.upsampling_filter)
+					    opt.crop_size, opt.batch_size, opt.upsampling_filter)
 		  image_loader.save_png(alpha_cache_file, alpha)
 	       end
 	    end
 	    cleanup_model(photo_scale2_model)
 	 elseif options.method == "noise1" then
-	    x = reconstruct.image(photo_noise1_model, x)
+	    x = reconstruct.image(photo_noise1_model, x, opt.crop_size, opt.batch_size)
 	    cleanup_model(photo_noise1_model)
 	 elseif options.method == "noise2" then
-	    x = reconstruct.image(photo_noise2_model, x)
+	    x = reconstruct.image(photo_noise2_model, x, opt.crop_size, opt.batch_size)
 	    cleanup_model(photo_noise2_model)
 	 elseif options.method == "noise3" then
-	    x = reconstruct.image(photo_noise3_model, x)
+	    x = reconstruct.image(photo_noise3_model, x, opt.crop_size, opt.batch_size)
 	    cleanup_model(photo_noise3_model)
 	 end
       end