nagadomi 9 лет назад
Родитель
Сommit
57e0f52b41
5 измененных файлов с 66 добавлено и 33 удалено
  1. 3 0
      lib/alpha_util.lua
  2. 18 8
      lib/image_loader.lua
  3. 42 22
      lib/reconstruct.lua
  4. 2 2
      waifu2x.lua
  5. 1 1
      web.lua

+ 3 - 0
lib/alpha_util.lua

@@ -36,6 +36,9 @@ function alpha_util.make_border(rgb, alpha, offset)
       mask = mask_weight:clone()
       mask[torch.gt(mask_weight, 0.0)] = 1
       mask_nega = (mask - 1):abs():byte()
+      if border:size(2) * border:size(3) > 1024*1024 then
+	 collectgarbage()
+      end
    end
    rgb[torch.gt(rgb, 1.0)] = 1.0
    rgb[torch.lt(rgb, 0.0)] = 0.0

+ 18 - 8
lib/image_loader.lua

@@ -9,19 +9,30 @@ local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
 local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
 local background_color = 0.5
 
-function image_loader.encode_png(rgb, depth)
+function image_loader.encode_png(rgb, depth, inplace)
+   if inplace == nil then
+      inplace = false
+   end
    depth = depth or 8
    rgb = iproc.byte2float(rgb)
    if depth < 16 then
-      rgb = rgb:clone():add(clip_eps8)
+      if inplace then
+	 rgb:add(clip_eps8)
+      else
+	 rgb = rgb:clone():add(clip_eps8)
+      end
       rgb[torch.lt(rgb, 0.0)] = 0.0
       rgb[torch.gt(rgb, 1.0)] = 1.0
-      rgb = rgb:mul(255):long():float():div(255)
+      rgb = rgb:mul(255):floor():div(255)
    else
-      rgb = rgb:clone():add(clip_eps16)
+      if inplace then
+	 rgb:add(clip_eps16)
+      else
+	 rgb = rgb:clone():add(clip_eps16)
+      end
       rgb[torch.lt(rgb, 0.0)] = 0.0
       rgb[torch.gt(rgb, 1.0)] = 1.0
-      rgb = rgb:mul(65535):long():float():div(65535)
+      rgb = rgb:mul(65535):floor():div(65535)
    end
    local im
    if rgb:size(1) == 4 then -- RGBA
@@ -34,9 +45,8 @@ function image_loader.encode_png(rgb, depth)
    end
    return im:depth(depth):format("PNG"):toString(9)
 end
-function image_loader.save_png(filename, rgb, depth)
-   depth = depth or 8
-   local blob = image_loader.encode_png(rgb, depth)
+function image_loader.save_png(filename, rgb, depth, inplace)
+   local blob = image_loader.encode_png(rgb, depth, inplace)
    local fp = io.open(filename, "wb")
    if not fp then
       error("IO error: " .. filename)

+ 42 - 22
lib/reconstruct.lua

@@ -16,7 +16,7 @@ local function reconstruct_y(model, x, offset, block_size)
 			   {i, i + block_size - 1},
 			   {j, j + block_size - 1}}
 	    input:copy(x[index])
-	    local output = model:forward(input):float():view(1, output_size, output_size)
+	    local output = model:forward(input):view(1, output_size, output_size)
 	    local output_index = {{},
 				  {i + offset, offset + i + output_size - 1},
 				  {offset + j, offset + j + output_size - 1}}
@@ -38,7 +38,7 @@ local function reconstruct_rgb(model, x, offset, block_size)
 			   {i, i + block_size - 1},
 			   {j, j + block_size - 1}}
 	    input:copy(x[index])
-	    local output = model:forward(input):float():view(3, output_size, output_size)
+	    local output = model:forward(input):view(3, output_size, output_size)
 	    local output_index = {{},
 				  {i + offset, offset + i + output_size - 1},
 				  {offset + j, offset + j + output_size - 1}}
@@ -89,16 +89,18 @@ function reconstruct.image_y(model, x, offset, block_size)
    local pad_w1 = offset
    local pad_h2 = (h - offset) - x:size(2)
    local pad_w2 = (w - offset) - x:size(3)
-   local yuv = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_y(model, yuv[1], offset, block_size)
+   x = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
+   local y = reconstruct_y(model, x[1], offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
-   yuv[1]:copy(y)
-   local output = image.yuv2rgb(iproc.crop(yuv,
+   x[1]:copy(y)
+   local output = image.yuv2rgb(iproc.crop(x,
 					   pad_w1, pad_h1,
-					   yuv:size(3) - pad_w2, yuv:size(2) - pad_h2))
+					   x:size(3) - pad_w2, x:size(2) - pad_h2))
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
+   x = nil
+   y = nil
    collectgarbage()
    
    return output
@@ -107,7 +109,9 @@ function reconstruct.scale_y(model, scale, x, offset, block_size)
    block_size = block_size or 128
    local x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
    x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
-
+   if x:size(2) * x:size(3) > 2048*2048 then
+      collectgarbage()
+   end
    local output_size = block_size - offset * 2
    local h_blocks = math.floor(x:size(2) / output_size) +
       ((x:size(2) % output_size == 0 and 0) or 1)
@@ -120,17 +124,20 @@ function reconstruct.scale_y(model, scale, x, offset, block_size)
    local pad_w1 = offset
    local pad_h2 = (h - offset) - x:size(2)
    local pad_w2 = (w - offset) - x:size(3)
-   local yuv_nn = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
-   local yuv_lanczos = image.rgb2yuv(iproc.padding(x_lanczos, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_y(model, yuv_nn[1], offset, block_size)
+   x = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
+   x_lanczos = image.rgb2yuv(iproc.padding(x_lanczos, pad_w1, pad_w2, pad_h1, pad_h2))
+   local y = reconstruct_y(model, x[1], offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
-   yuv_lanczos[1]:copy(y)
-   local output = image.yuv2rgb(iproc.crop(yuv_lanczos,
+   x_lanczos[1]:copy(y)
+   local output = image.yuv2rgb(iproc.crop(x_lanczos,
 					   pad_w1, pad_h1,
-					   yuv_lanczos:size(3) - pad_w2, yuv_lanczos:size(2) - pad_h2))
+					   x_lanczos:size(3) - pad_w2, x_lanczos:size(2) - pad_h2))
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
+   x = nil
+   x_lanczos = nil
+   y = nil
    collectgarbage()
    
    return output
@@ -149,21 +156,29 @@ function reconstruct.image_rgb(model, x, offset, block_size)
    local pad_w1 = offset
    local pad_h2 = (h - offset) - x:size(2)
    local pad_w2 = (w - offset) - x:size(3)
-   local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
-   local y = reconstruct_rgb(model, input, offset, block_size)
+
+   x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
+   if x:size(2) * x:size(3) > 2048*2048 then
+      collectgarbage()
+   end
+   local y = reconstruct_rgb(model, x, offset, block_size)
    local output = iproc.crop(y,
 			     pad_w1, pad_h1,
 			     y:size(3) - pad_w2, y:size(2) - pad_h2)
-   collectgarbage()
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
-   
+   x = nil
+   y = nil
+   collectgarbage()
+
    return output
 end
 function reconstruct.scale_rgb(model, scale, x, offset, block_size)
    block_size = block_size or 128
    x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
-
+   if x:size(2) * x:size(3) > 2048*2048 then
+      collectgarbage()
+   end
    local output_size = block_size - offset * 2
    local h_blocks = math.floor(x:size(2) / output_size) +
       ((x:size(2) % output_size == 0 and 0) or 1)
@@ -176,15 +191,20 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
    local pad_w1 = offset
    local pad_h2 = (h - offset) - x:size(2)
    local pad_w2 = (w - offset) - x:size(3)
-   local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
-   local y = reconstruct_rgb(model, input, offset, block_size)
+   x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
+   if x:size(2) * x:size(3) > 2048*2048 then
+      collectgarbage()
+   end
+   local y = reconstruct_rgb(model, x, offset, block_size)
    local output = iproc.crop(y,
 			     pad_w1, pad_h1,
 			     y:size(3) - pad_w2, y:size(2) - pad_h2)
    output[torch.lt(output, 0)] = 0
    output[torch.gt(output, 1)] = 1
+   x = nil
+   y = nil
    collectgarbage()
-   
+
    return output
 end
 

+ 2 - 2
waifu2x.lua

@@ -65,7 +65,7 @@ local function convert_image(opt)
    else
       error("undefined method:" .. opt.method)
    end
-   image_loader.save_png(opt.o, new_x, opt.depth)
+   image_loader.save_png(opt.o, new_x, opt.depth, true)
    print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
 end
 local function convert_frames(opt)
@@ -162,7 +162,7 @@ local function convert_frames(opt)
 	 else
 	    output = string.format(opt.o, i)
 	 end
-	 image_loader.save_png(output, new_x, opt.depth)
+	 image_loader.save_png(output, new_x, opt.depth, true)
 	 xlua.progress(i, #lines)
 	 if i % 10 == 0 then
 	    collectgarbage()

+ 1 - 1
web.lua

@@ -259,7 +259,7 @@ function APIHandler:post()
       else
 	 name = uuid() .. ".png"
       end
-      local blob = image_loader.encode_png(alpha_util.composite(x, alpha))
+      local blob = image_loader.encode_png(alpha_util.composite(x, alpha), 8, true)
 
       self:set_header("Content-Length", string.format("%d", #blob))
       if download > 0 then