|
@@ -2,7 +2,8 @@ require 'image'
|
|
|
local iproc = require 'iproc'
|
|
|
local srcnn = require 'srcnn'
|
|
|
|
|
|
-local function reconstruct_nn(model, x, inner_scale, offset, block_size)
|
|
|
+local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_size)
|
|
|
+ batch_size = batch_size or 1
|
|
|
if x:dim() == 2 then
|
|
|
x = x:reshape(1, x:size(1), x:size(2))
|
|
|
end
|
|
@@ -12,24 +13,46 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size)
|
|
|
local output_block_size = block_size
|
|
|
local output_size = output_block_size - offset * 2
|
|
|
local output_size_in_input = input_block_size - math.ceil(offset / inner_scale) * 2
|
|
|
- local input = torch.CudaTensor(1, ch, input_block_size, input_block_size)
|
|
|
+ local input_indexes = {}
|
|
|
+ local output_indexes = {}
|
|
|
for i = 1, x:size(2), output_size_in_input do
|
|
|
for j = 1, x:size(3), output_size_in_input do
|
|
|
if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
|
|
|
local index = {{},
|
|
|
{i, i + input_block_size - 1},
|
|
|
{j, j + input_block_size - 1}}
|
|
|
- input:copy(x[index])
|
|
|
- local output = model:forward(input)
|
|
|
- output = output:view(ch, output_size, output_size)
|
|
|
local ii = (i - 1) * inner_scale + 1
|
|
|
local jj = (j - 1) * inner_scale + 1
|
|
|
local output_index = {{}, { ii , ii + output_size - 1 },
|
|
|
{ jj, jj + output_size - 1}}
|
|
|
- new_x[output_index]:copy(output)
|
|
|
+ table.insert(input_indexes, index)
|
|
|
+ table.insert(output_indexes, output_index)
|
|
|
end
|
|
|
end
|
|
|
end
|
|
|
+ local input = torch.Tensor(batch_size, ch, input_block_size, input_block_size)
|
|
|
+ local input_cuda = torch.CudaTensor(batch_size, ch, input_block_size, input_block_size)
|
|
|
+ for i = 1, #input_indexes, batch_size do
|
|
|
+ local c = 0
|
|
|
+ local output
|
|
|
+ for j = 0, batch_size - 1 do
|
|
|
+ if i + j > #input_indexes then
|
|
|
+ break
|
|
|
+ end
|
|
|
+ input[j+1]:copy(x[input_indexes[i + j]])
|
|
|
+ c = c + 1
|
|
|
+ end
|
|
|
+ input_cuda:copy(input)
|
|
|
+ if c == batch_size then
|
|
|
+ output = model:forward(input_cuda)
|
|
|
+ else
|
|
|
+ output = model:forward(input_cuda:narrow(1, 1, c))
|
|
|
+ end
|
|
|
+ --output = output:view(batch_size, ch, output_size, output_size)
|
|
|
+ for j = 0, c - 1 do
|
|
|
+ new_x[output_indexes[i + j]]:copy(output[j+1])
|
|
|
+ end
|
|
|
+ end
|
|
|
return new_x
|
|
|
end
|
|
|
local reconstruct = {}
|
|
@@ -72,11 +95,11 @@ local function padding_params(x, model, block_size)
|
|
|
p.pad_w2 = (w - input_offset) - p.x_w
|
|
|
return p
|
|
|
end
|
|
|
-function reconstruct.image_y(model, x, offset, block_size)
|
|
|
+function reconstruct.image_y(model, x, offset, block_size, batch_size)
|
|
|
block_size = block_size or 128
|
|
|
local p = padding_params(x, model, block_size)
|
|
|
x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
|
|
|
- local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
|
|
|
+ local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
|
|
|
x = iproc.crop(x, p.pad_w1, p.pad_w2, p.pad_w1 + p.x_w, p.pad_w2 + p.x_h)
|
|
|
y = iproc.crop(y, 0, 0, p.x_w, p.x_h)
|
|
|
y[torch.lt(y, 0)] = 0
|
|
@@ -91,7 +114,7 @@ function reconstruct.image_y(model, x, offset, block_size)
|
|
|
|
|
|
return output
|
|
|
end
|
|
|
-function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
|
|
|
+function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size, upsampling_filter)
|
|
|
upsampling_filter = upsampling_filter or "Box"
|
|
|
block_size = block_size or 128
|
|
|
local x_lanczos
|
|
@@ -107,7 +130,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_fil
|
|
|
end
|
|
|
x = image.rgb2yuv(iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2))
|
|
|
x_lanczos = image.rgb2yuv(x_lanczos)
|
|
|
- local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size)
|
|
|
+ local y = reconstruct_nn(model, x[1], p.inner_scale, offset, block_size, batch_size)
|
|
|
y = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
|
|
|
y[torch.lt(y, 0)] = 0
|
|
|
y[torch.gt(y, 1)] = 1
|
|
@@ -122,14 +145,14 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_fil
|
|
|
|
|
|
return output
|
|
|
end
|
|
|
-function reconstruct.image_rgb(model, x, offset, block_size)
|
|
|
+function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
|
|
|
block_size = block_size or 128
|
|
|
local p = padding_params(x, model, block_size)
|
|
|
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
|
|
|
if p.x_w * p.x_h > 2048*2048 then
|
|
|
collectgarbage()
|
|
|
end
|
|
|
- local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
|
|
|
+ local y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
|
|
|
local output = iproc.crop(y, 0, 0, p.x_w, p.x_h)
|
|
|
output[torch.lt(output, 0)] = 0
|
|
|
output[torch.gt(output, 1)] = 1
|
|
@@ -139,7 +162,7 @@ function reconstruct.image_rgb(model, x, offset, block_size)
|
|
|
|
|
|
return output
|
|
|
end
|
|
|
-function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
|
|
|
+function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size, upsampling_filter)
|
|
|
upsampling_filter = upsampling_filter or "Box"
|
|
|
block_size = block_size or 128
|
|
|
if not reconstruct.has_resize(model) then
|
|
@@ -151,7 +174,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_f
|
|
|
collectgarbage()
|
|
|
end
|
|
|
local y
|
|
|
- y = reconstruct_nn(model, x, p.inner_scale, offset, block_size)
|
|
|
+ y = reconstruct_nn(model, x, p.inner_scale, offset, block_size, batch_size)
|
|
|
local output = iproc.crop(y, 0, 0, p.x_w * p.inner_scale, p.x_h * p.inner_scale)
|
|
|
output[torch.lt(output, 0)] = 0
|
|
|
output[torch.gt(output, 1)] = 1
|