|
@@ -109,8 +109,27 @@ local function padding_params(x, model, block_size)
|
|
|
p.pad_w2 = (w - input_offset) - p.x_w
|
|
|
return p
|
|
|
end
|
|
|
+local function find_valid_block_size(model, block_size)
|
|
|
+ if model.w2nn_input_size ~= nil then
|
|
|
+ return model.w2nn_input_size
|
|
|
+ elseif model.w2nn_valid_input_size ~= nil then
|
|
|
+ local best_size = 0
|
|
|
+ local best_diff = 10000
|
|
|
+ for i = 1, #model.w2nn_valid_input_size do
|
|
|
+ local diff = math.abs(model.w2nn_valid_input_size[i] - block_size)
|
|
|
+ if diff < best_diff then
|
|
|
+ best_size = model.w2nn_valid_input_size[i]
|
|
|
+ best_diff = diff
|
|
|
+ end
|
|
|
+ end
|
|
|
+ assert(best_size > 0)
|
|
|
+ return best_size
|
|
|
+ else
|
|
|
+ return block_size
|
|
|
+ end
|
|
|
+end
|
|
|
function reconstruct.image_y(model, x, offset, block_size, batch_size)
|
|
|
- block_size = block_size or 128
|
|
|
+ block_size = find_valid_block_size(model, block_size or 128)
|
|
|
local p = padding_params(x, model, block_size)
|
|
|
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
|
|
|
x = x:cuda()
|
|
@@ -126,7 +145,7 @@ function reconstruct.image_y(model, x, offset, block_size, batch_size)
|
|
|
return output
|
|
|
end
|
|
|
function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
|
|
|
- block_size = block_size or 128
|
|
|
+ block_size = find_valid_block_size(model, block_size or 128)
|
|
|
local x_lanczos
|
|
|
if reconstruct.has_resize(model) then
|
|
|
x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
|
|
@@ -153,7 +172,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
|
|
|
return output
|
|
|
end
|
|
|
function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
|
|
|
- block_size = block_size or 128
|
|
|
+ block_size = find_valid_block_size(model, block_size or 128)
|
|
|
local p = padding_params(x, model, block_size)
|
|
|
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
|
|
|
if p.x_w * p.x_h > 2048*2048 then
|
|
@@ -168,7 +187,7 @@ function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
|
|
|
return output
|
|
|
end
|
|
|
function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
|
|
|
- block_size = block_size or 128
|
|
|
+ block_size = find_valid_block_size(model, block_size or 128)
|
|
|
if not reconstruct.has_resize(model) then
|
|
|
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
|
|
|
end
|
|
@@ -186,9 +205,6 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
|
|
|
return output
|
|
|
end
|
|
|
function reconstruct.image(model, x, block_size)
|
|
|
- if model.w2nn_input_size then
|
|
|
- block_size = model.w2nn_input_size
|
|
|
- end
|
|
|
local i2rgb = false
|
|
|
if x:size(1) == 1 then
|
|
|
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
|
@@ -211,9 +227,6 @@ function reconstruct.image(model, x, block_size)
|
|
|
return x
|
|
|
end
|
|
|
function reconstruct.scale(model, scale, x, block_size)
|
|
|
- if model.w2nn_input_size then
|
|
|
- block_size = model.w2nn_input_size
|
|
|
- end
|
|
|
local i2rgb = false
|
|
|
if x:size(1) == 1 then
|
|
|
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|