|
@@ -13,6 +13,21 @@ local function pcacov(x)
|
|
|
local ce, cv = torch.symeig(c, 'V')
|
|
|
return ce, cv
|
|
|
end
|
|
|
+
|
|
|
+function random_rect_size(rect_min, rect_max)
|
|
|
+ local r = torch.Tensor(2):uniform():cmul(torch.Tensor({rect_max - rect_min, rect_max - rect_min})):int()
|
|
|
+ local rect_h = r[1] + rect_min
|
|
|
+ local rect_w = r[2] + rect_min
|
|
|
+ return rect_h, rect_w
|
|
|
+end
|
|
|
+function random_rect(height, width, rect_h, rect_w)
|
|
|
+ local r = torch.Tensor(2):uniform():cmul(torch.Tensor({height - 1 - rect_h, width-1 - rect_w})):int()
|
|
|
+ local rect_y1 = r[1] + 1
|
|
|
+ local rect_x1 = r[2] + 1
|
|
|
+ local rect_x2 = rect_x1 + rect_w
|
|
|
+ local rect_y2 = rect_y1 + rect_h
|
|
|
+ return {x1 = rect_x1, y1 = rect_y1, x2 = rect_x2, y2 = rect_y2}
|
|
|
+end
|
|
|
function data_augmentation.erase(src, p, n, rect_min, rect_max)
|
|
|
if torch.uniform() < p then
|
|
|
local src, conversion = iproc.byte2float(src)
|
|
@@ -21,17 +36,12 @@ function data_augmentation.erase(src, p, n, rect_min, rect_max)
|
|
|
local height = src:size(2)
|
|
|
local width = src:size(3)
|
|
|
for i = 1, n do
|
|
|
- local r = torch.Tensor(4):uniform():cmul(torch.Tensor({height-1, width-1, rect_max - rect_min, rect_max - rect_min})):int()
|
|
|
- local rect_y1 = r[1] + 1
|
|
|
- local rect_x1 = r[2] + 1
|
|
|
- local rect_h = r[3] + rect_min
|
|
|
- local rect_w = r[4] + rect_min
|
|
|
- local rect_x2 = math.min(rect_x1 + rect_w, width)
|
|
|
- local rect_y2 = math.min(rect_y1 + rect_h, height)
|
|
|
- local sub_rect = src:sub(1, ch, rect_y1, rect_y2, rect_x1, rect_x2)
|
|
|
- for i = 1, ch do
|
|
|
- sub_rect[i]:fill(src[i][rect_y1][rect_x1])
|
|
|
- end
|
|
|
+ local rect_h, rect_w = random_rect_size(rect_min, rect_max)
|
|
|
+ local rect1 = random_rect(height, width, rect_h, rect_w)
|
|
|
+ local rect2 = random_rect(height, width, rect_h, rect_w)
|
|
|
+ dest_rect = src:sub(1, ch, rect1.y1, rect1.y2, rect1.x1, rect1.x2)
|
|
|
+ src_rect = src:sub(1, ch, rect2.y1, rect2.y2, rect2.x1, rect2.x2)
|
|
|
+ dest_rect:copy(src_rect:clone())
|
|
|
end
|
|
|
if conversion then
|
|
|
src = iproc.float2byte(src)
|