Explorar o código

tuning a little

nagadomi %!s(int64=9) %!d(string=hai) anos
pai
achega
14330e919c
Modificáronse 1 ficheiros con 12 adicións e 25 borrados
  1. 12 25
      lib/reconstruct.lua

+ 12 - 25
lib/reconstruct.lua

@@ -98,22 +98,17 @@ end
 function reconstruct.image_y(model, x, offset, block_size, batch_size)
 function reconstruct.image_y(model, x, offset, block_size, batch_size)
    block_size = block_size or 128
    block_size = block_size or 128
    local p = padding_params(x, model, block_size)
    local p = padding_params(x, model, block_size)
-   x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
+   x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
+   x = x:cuda()
+   x = image.rgb2yuv(x)
    local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
    local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
-
    x = iproc.crop(x, p.pad_w1, p.pad_h1, p.pad_w1 + p.x_w, p.pad_h1 + p.x_h)
    x = iproc.crop(x, p.pad_w1, p.pad_h1, p.pad_w1 + p.x_w, p.pad_h1 + p.x_h)
-   y = iproc.crop(y, 0, 0, p.x_w, p.x_h)
-
-   y[torch.lt(y, 0)] = 0
-   y[torch.gt(y, 1)] = 1
+   y = iproc.crop(y, 0, 0, p.x_w, p.x_h):clamp(0, 1)
    x[1]:copy(y)
    x[1]:copy(y)
-   local output = image.yuv2rgb(x)
-   output[torch.lt(output, 0)] = 0
-   output[torch.gt(output, 1)] = 1
+   local output = image.yuv2rgb(x):clamp(0, 1):float()
    x = nil
    x = nil
    y = nil
    y = nil
    collectgarbage()
    collectgarbage()
-   
    return output
    return output
 end
 end
 function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
 function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
@@ -129,21 +124,18 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
    if p.x_w * p.x_h > 2048*2048 then
    if p.x_w * p.x_h > 2048*2048 then
       collectgarbage()
       collectgarbage()
    end
    end
-   x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
+   x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
+   x = x:cuda()
+   x = image.rgb2yuv(x)
    x_lanczos = image.rgb2yuv(x_lanczos)
    x_lanczos = image.rgb2yuv(x_lanczos)
    local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
    local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
-   y = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
-   y[torch.lt(y, 0)] = 0
-   y[torch.gt(y, 1)] = 1
+   y = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale):clamp(0, 1)
    x_lanczos[1]:copy(y)
    x_lanczos[1]:copy(y)
-   local output = image.yuv2rgb(x_lanczos)
-   output[torch.lt(output, 0)] = 0
-   output[torch.gt(output, 1)] = 1
+   local output = image.yuv2rgb(x_lanczos:cuda()):clamp(0, 1):float()
    x = nil
    x = nil
    x_lanczos = nil
    x_lanczos = nil
    y = nil
    y = nil
    collectgarbage()
    collectgarbage()
-   
    return output
    return output
 end
 end
 function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
 function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
@@ -154,9 +146,7 @@ function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
       collectgarbage()
       collectgarbage()
    end
    end
    local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
    local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
-   local output = iproc.crop(y, 0, 0, p.x_w, p.x_h)
-   output[torch.lt(output, 0)] = 0
-   output[torch.gt(output, 1)] = 1
+   local output = iproc.crop(y, 0, 0, p.x_w, p.x_h):clamp(0, 1)
    x = nil
    x = nil
    y = nil
    y = nil
    collectgarbage()
    collectgarbage()
@@ -175,13 +165,10 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
    end
    end
    local y
    local y
    y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
    y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
-   local output = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
-   output[torch.lt(output, 0)] = 0
-   output[torch.gt(output, 1)] = 1
+   local output = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale):clamp(0, 1)
    x = nil
    x = nil
    y = nil
    y = nil
    collectgarbage()
    collectgarbage()
-
    return output
    return output
 end
 end
 function reconstruct.image(model, x, block_size)
 function reconstruct.image(model, x, block_size)