|
@@ -209,63 +209,90 @@ function reconstruct.scale(model, scale, x, block_size, upsampling_filter)
|
|
end
|
|
end
|
|
return x
|
|
return x
|
|
end
|
|
end
|
|
-local function tta(f, model, x, block_size)
|
|
|
|
|
|
+local function tr_f(a)
|
|
|
|
+ return a:transpose(2, 3):contiguous()
|
|
|
|
+end
|
|
|
|
+local function itr_f(a)
|
|
|
|
+ return a:transpose(2, 3):contiguous()
|
|
|
|
+end
|
|
|
|
+local augmented_patterns = {
|
|
|
|
+ {
|
|
|
|
+ forward = function (a) return a end,
|
|
|
|
+ backward = function (a) return a end
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ forward = function (a) return image.hflip(a) end,
|
|
|
|
+ backward = function (a) return image.hflip(a) end
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ forward = function (a) return image.vflip(a) end,
|
|
|
|
+ backward = function (a) return image.vflip(a) end
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ forward = function (a) return image.hflip(image.vflip(a)) end,
|
|
|
|
+ backward = function (a) return image.vflip(image.hflip(a)) end
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ forward = function (a) return tr_f(a) end,
|
|
|
|
+ backward = function (a) return itr_f(a) end
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ forward = function (a) return image.hflip(tr_f(a)) end,
|
|
|
|
+ backward = function (a) return itr_f(image.hflip(a)) end
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ forward = function (a) return image.vflip(tr_f(a)) end,
|
|
|
|
+ backward = function (a) return itr_f(image.vflip(a)) end
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ forward = function (a) return image.hflip(image.vflip(tr_f(a))) end,
|
|
|
|
+ backward = function (a) return itr_f(image.vflip(image.hflip(a))) end
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+local function get_augmented_patterns(n)
|
|
|
|
+ if n == 2 then
|
|
|
|
+ return {augmented_patterns[1], augmented_patterns[5]}
|
|
|
|
+ elseif n == 4 then
|
|
|
|
+ return {augmented_patterns[1], augmented_patterns[5],
|
|
|
|
+ augmented_patterns[2], augmented_patterns[7]}
|
|
|
|
+ elseif n == 8 then
|
|
|
|
+ return augmented_patterns
|
|
|
|
+ else
|
|
|
|
+ error("unsupported TTA level: " .. n)
|
|
|
|
+ end
|
|
|
|
+end
|
|
|
|
+local function tta(f, n, model, x, block_size)
|
|
local average = nil
|
|
local average = nil
|
|
local offset = reconstruct.offset_size(model)
|
|
local offset = reconstruct.offset_size(model)
|
|
- for i = 1, 4 do
|
|
|
|
- local flip_f, iflip_f
|
|
|
|
- if i == 1 then
|
|
|
|
- flip_f = function (a) return a end
|
|
|
|
- iflip_f = function (a) return a end
|
|
|
|
- elseif i == 2 then
|
|
|
|
- flip_f = image.vflip
|
|
|
|
- iflip_f = image.vflip
|
|
|
|
- elseif i == 3 then
|
|
|
|
- flip_f = image.hflip
|
|
|
|
- iflip_f = image.hflip
|
|
|
|
- elseif i == 4 then
|
|
|
|
- flip_f = function (a) return image.hflip(image.vflip(a)) end
|
|
|
|
- iflip_f = function (a) return image.vflip(image.hflip(a)) end
|
|
|
|
- end
|
|
|
|
- for j = 1, 2 do
|
|
|
|
- local tr_f, itr_f
|
|
|
|
- if j == 1 then
|
|
|
|
- tr_f = function (a) return a end
|
|
|
|
- itr_f = function (a) return a end
|
|
|
|
- elseif j == 2 then
|
|
|
|
- tr_f = function(a) return a:transpose(2, 3):contiguous() end
|
|
|
|
- itr_f = function(a) return a:transpose(2, 3):contiguous() end
|
|
|
|
- end
|
|
|
|
- local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)),
|
|
|
|
- offset, block_size)))
|
|
|
|
- if not average then
|
|
|
|
- average = out
|
|
|
|
- else
|
|
|
|
- average:add(out)
|
|
|
|
- end
|
|
|
|
|
|
+ local augments = get_augmented_patterns(n)
|
|
|
|
+ for i = 1, #augments do
|
|
|
|
+ local out = augments[i].backward(f(model, augments[i].forward(x), offset, block_size))
|
|
|
|
+ if not average then
|
|
|
|
+ average = out
|
|
|
|
+ else
|
|
|
|
+ average:add(out)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
- return average:div(8.0)
|
|
|
|
|
|
+ return average:div(#augments)
|
|
end
|
|
end
|
|
-function reconstruct.image_tta(model, x, block_size)
|
|
|
|
|
|
+function reconstruct.image_tta(model, n, x, block_size)
|
|
if reconstruct.is_rgb(model) then
|
|
if reconstruct.is_rgb(model) then
|
|
- return tta(reconstruct.image_rgb, model, x, block_size)
|
|
|
|
|
|
+ return tta(reconstruct.image_rgb, n, model, x, block_size)
|
|
else
|
|
else
|
|
- return tta(reconstruct.image_y, model, x, block_size)
|
|
|
|
|
|
+ return tta(reconstruct.image_y, n, model, x, block_size)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
-function reconstruct.scale_tta(model, scale, x, block_size, upsampling_filter)
|
|
|
|
|
|
+function reconstruct.scale_tta(model, n, scale, x, block_size, upsampling_filter)
|
|
if reconstruct.is_rgb(model) then
|
|
if reconstruct.is_rgb(model) then
|
|
local f = function (model, x, offset, block_size)
|
|
local f = function (model, x, offset, block_size)
|
|
return reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
|
|
return reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
|
|
end
|
|
end
|
|
- return tta(f, model, x, block_size)
|
|
|
|
-
|
|
|
|
|
|
+ return tta(f, n, model, x, block_size)
|
|
else
|
|
else
|
|
local f = function (model, x, offset, block_size)
|
|
local f = function (model, x, offset, block_size)
|
|
return reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
|
|
return reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
|
|
end
|
|
end
|
|
- return tta(f, model, x, block_size)
|
|
|
|
|
|
+ return tta(f, n, model, x, block_size)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
|
|
|