|
@@ -16,10 +16,19 @@ local function random_half(src, p)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
local function crop_if_large(src, max_size)
|
|
local function crop_if_large(src, max_size)
|
|
|
|
+ local tries = 4
|
|
if src:size(2) > max_size and src:size(3) > max_size then
|
|
if src:size(2) > max_size and src:size(3) > max_size then
|
|
- local yi = torch.random(0, src:size(2) - max_size)
|
|
|
|
- local xi = torch.random(0, src:size(3) - max_size)
|
|
|
|
- return iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
|
|
|
|
|
|
+ local rect
|
|
|
|
+ for i = 1, tries do
|
|
|
|
+ local yi = torch.random(0, src:size(2) - max_size)
|
|
|
|
+ local xi = torch.random(0, src:size(3) - max_size)
|
|
|
|
+ rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
|
|
|
|
+ -- ignore simple background
|
|
|
|
+ if rect:float():std() >= 0 then
|
|
|
|
+ break
|
|
|
|
+ end
|
|
|
|
+ end
|
|
|
|
+ return rect
|
|
else
|
|
else
|
|
return src
|
|
return src
|
|
end
|
|
end
|
|
@@ -29,7 +38,7 @@ local function preprocess(src, crop_size, options)
|
|
if options.random_half then
|
|
if options.random_half then
|
|
dest = random_half(dest)
|
|
dest = random_half(dest)
|
|
end
|
|
end
|
|
- dest = crop_if_large(dest, math.max(crop_size * 4, 512))
|
|
|
|
|
|
+ dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
|
dest = data_augmentation.flip(dest)
|
|
dest = data_augmentation.flip(dest)
|
|
if options.color_noise then
|
|
if options.color_noise then
|
|
dest = data_augmentation.color_noise(dest)
|
|
dest = data_augmentation.color_noise(dest)
|
|
@@ -52,7 +61,9 @@ local function active_cropping(x, y, size, p, tries)
|
|
return xc, yc
|
|
return xc, yc
|
|
else
|
|
else
|
|
local samples = {}
|
|
local samples = {}
|
|
- local sum_mse = 0
|
|
|
|
|
|
+ local best_se = 0.0
|
|
|
|
+ local best_xc, best_yc
|
|
|
|
+ local m = torch.FloatTensor(x:size(1), size, size)
|
|
for i = 1, tries do
|
|
for i = 1, tries do
|
|
local xi = torch.random(0, y:size(3) - (size + 1))
|
|
local xi = torch.random(0, y:size(3) - (size + 1))
|
|
local yi = torch.random(0, y:size(2) - (size + 1))
|
|
local yi = torch.random(0, y:size(2) - (size + 1))
|
|
@@ -60,17 +71,14 @@ local function active_cropping(x, y, size, p, tries)
|
|
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
|
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
|
local xcf = iproc.byte2float(xc)
|
|
local xcf = iproc.byte2float(xc)
|
|
local ycf = iproc.byte2float(yc)
|
|
local ycf = iproc.byte2float(yc)
|
|
- local mse = (xcf - ycf):pow(2):mean()
|
|
|
|
- sum_mse = sum_mse + mse
|
|
|
|
- table.insert(samples, {xc = xc, yc = yc, mse = mse})
|
|
|
|
- end
|
|
|
|
- if sum_mse > 0 then
|
|
|
|
- table.sort(samples,
|
|
|
|
- function (a, b)
|
|
|
|
- return a.mse > b.mse
|
|
|
|
- end)
|
|
|
|
|
|
+ local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
|
|
|
|
+ if se >= best_se then
|
|
|
|
+ best_xc = xcf
|
|
|
|
+ best_yc = ycf
|
|
|
|
+ best_se = se
|
|
|
|
+ end
|
|
end
|
|
end
|
|
- return samples[1].xc, samples[1].yc
|
|
|
|
|
|
+ return best_xc, best_yc
|
|
end
|
|
end
|
|
end
|
|
end
|
|
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|
@@ -83,6 +91,7 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|
"SincFast", -- 0.014095824314306
|
|
"SincFast", -- 0.014095824314306
|
|
"Jinc", -- 0.014244299255442
|
|
"Jinc", -- 0.014244299255442
|
|
}
|
|
}
|
|
|
|
+ local unstable_region_offset = 8
|
|
local downscale_filter = filters[torch.random(1, #filters)]
|
|
local downscale_filter = filters[torch.random(1, #filters)]
|
|
local y = preprocess(src, size, options)
|
|
local y = preprocess(src, size, options)
|
|
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
|
|
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
|
|
@@ -90,6 +99,13 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
|
|
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
|
|
y:size(2) * down_scale, downscale_filter),
|
|
y:size(2) * down_scale, downscale_filter),
|
|
y:size(3), y:size(2))
|
|
y:size(3), y:size(2))
|
|
|
|
+ x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
|
|
|
+ x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
|
|
|
+ y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
|
|
|
+ y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
|
|
|
+ assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
|
|
|
+ assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
|
|
|
+
|
|
local batch = {}
|
|
local batch = {}
|
|
for i = 1, n do
|
|
for i = 1, n do
|
|
local xc, yc = active_cropping(x, y,
|
|
local xc, yc = active_cropping(x, y,
|
|
@@ -108,8 +124,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|
return batch
|
|
return batch
|
|
end
|
|
end
|
|
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
|
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
|
|
|
+ local unstable_region_offset = 8
|
|
local y = preprocess(src, size, options)
|
|
local y = preprocess(src, size, options)
|
|
local x = y
|
|
local x = y
|
|
|
|
+
|
|
for i = 1, #quality do
|
|
for i = 1, #quality do
|
|
x = gm.Image(x, "RGB", "DHW")
|
|
x = gm.Image(x, "RGB", "DHW")
|
|
x:format("jpeg")
|
|
x:format("jpeg")
|
|
@@ -122,7 +140,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
|
x:fromBlob(blob, len)
|
|
x:fromBlob(blob, len)
|
|
x = x:toTensor("byte", "RGB", "DHW")
|
|
x = x:toTensor("byte", "RGB", "DHW")
|
|
end
|
|
end
|
|
- -- TODO: use shift_1px after compression?
|
|
|
|
|
|
+ x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
|
|
|
+ x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
|
|
|
+ y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
|
|
|
+ y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
|
|
|
+ assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
|
|
|
+ assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
|
|
|
|
|
local batch = {}
|
|
local batch = {}
|
|
for i = 1, n do
|
|
for i = 1, n do
|
|
@@ -152,7 +175,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
|
end
|
|
end
|
|
elseif level == 2 then
|
|
elseif level == 2 then
|
|
local r = torch.uniform()
|
|
local r = torch.uniform()
|
|
- if torch.uniform() > 0.8 then
|
|
|
|
|
|
+ if torch.uniform() > 0.9 then
|
|
return pairwise_transform.jpeg_(src, {},
|
|
return pairwise_transform.jpeg_(src, {},
|
|
size, offset, n, options)
|
|
size, offset, n, options)
|
|
else
|
|
else
|