-
Notifications
You must be signed in to change notification settings - Fork 0
/
augment.py
148 lines (114 loc) · 5.25 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
Module to augment a set of labeled images, with bounding box transformations.
Modify get_augmentations() and multi_aug() in userdefs.py as necessary for your specific
deployment needs with the albumentations library.
"""
import os
from collections import Counter
import cv2
import numpy as np
import albumentations as alb
from tqdm import tqdm
from retrain.utils import get_label_path
from userdefs import multi_aug, get_augmentations
class Augmenter:
    """Wrapper class for an ImageFolder object to begin augmenting images."""

    def __init__(self, img_folder):
        self.img_folder = img_folder

    def get_incr_factors(self, imgs_per_class):
        """Get a dictionary mapping each image in the image folder to the number
        of times that image should be augmented.

        Parameters:
            imgs_per_class  Desired number of additional samples per class.
        """
        desired = {i: imgs_per_class for i in range(self.img_folder.num_classes)}
        img_dict = self.img_folder.make_img_dict()
        incr_factors = {img: 0 for img in img_dict.keys()}

        # Visit label-dense images first so one augmentation credits
        # several classes at once.
        imgs_by_label_count = dict(
            sorted(img_dict.items(), key=lambda x: len(x[1]), reverse=True)
        )

        # This algorithm could be optimized by sorting by the labels
        # by the most desired and normalizing classes relative to each other
        while sum(desired.values()) > 0:
            progress = False
            for img, labels in imgs_by_label_count.items():
                label_counts = Counter(labels)
                # Skip images that would overshoot the target for any class.
                if any(desired[label] < count for label, count in label_counts.items()):
                    continue
                for label, count in label_counts.items():
                    desired[label] -= count
                incr_factors[img] += 1
                progress = True
            # BUGFIX: if every remaining image would overshoot some class
            # target, the original loop never terminated (desired stays > 0
            # but no image is ever incremented). Stop once a full pass makes
            # no progress.
            if not progress:
                break
        return incr_factors

    def augment(self, imgs_per_class, major_aug, min_visibility=0.75):
        """Augment all images in the image folder, adding the augmentations to the folder.

        Parameters:
            imgs_per_class  Target number of images. If there are more samples in the folder
                than this number for a class, no augmentation will be performed.
            major_aug       A boolean variable determining if 'major' transformations will be used.
            min_visibility  Minimum visibility of the resultant bounding boxes after augmentation.
                This is a value in (0.0, 1.0] relative to the area of the bounding box.
        """
        incr_factors = self.get_incr_factors(imgs_per_class)

        bbox_params = alb.BboxParams(
            "yolo", min_visibility=min_visibility, label_fields=["classes"]
        )
        aug = multi_aug(get_augmentations(), major_aug, bbox_params)

        # Context manager guarantees the progress bar is closed even if
        # augment_img raises partway through.
        with tqdm(
            desc="Augmenting training images", total=sum(incr_factors.values())
        ) as pbar:
            for img, count in incr_factors.items():
                augment_img(aug, "compose", img, count=count)
                new_imgs = {
                    f"{img[:-4].replace('images', 'aug-images')}_compose-{i}.png"
                    for i in range(count)
                }
                self.img_folder.imgs.update(new_imgs)
                self.img_folder.labels.update({get_label_path(img) for img in new_imgs})
                pbar.update(count)
def augment_img(aug, suffix, img_path, count=1):
    """Iteratively augment a single image with a given augmentation function.

    Generates transformed labels for each augmentation.

    Parameters:
        aug       Augmentation function from albumentations
        suffix    String suffix appended to each augmentation file and label
        img_path  Filesystem path to the image to augment
        count     Number of augmentations to perform
    """
    img = cv2.imread(img_path)
    # OpenCV loads BGR; albumentations expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    base_name = img_path[:-4]
    label_path = get_label_path(img_path)
    boxes, field_ids = parse_label(label_path)

    i = 0
    while i < count:
        aug_path = f"{base_name.replace('images', 'aug-images')}_{suffix}-{i}.png"
        new_txt_path = get_label_path(aug_path)
        os.makedirs(os.path.dirname(new_txt_path), exist_ok=True)
        os.makedirs(os.path.dirname(aug_path), exist_ok=True)

        # Skip augmentations that already exist (e.g. from a previous run).
        if os.path.exists(aug_path) and os.path.exists(new_txt_path):
            i += 1
            continue

        try:
            result = aug(image=img, bboxes=boxes, classes=field_ids)
        except IndexError:
            # Some transforms can fail on degenerate boxes; retry this index.
            continue

        aug_img = result["image"]
        new_bboxes = [" ".join(map(str, bbox)) for bbox in result["bboxes"]]

        # Check if bounding boxes have been removed due to visibility criteria
        if len(new_bboxes) != len(boxes):
            continue

        # BUGFIX: convert the *augmented* image back to BGR for imwrite.
        # The original converted the source `img` in place, so every saved
        # file had swapped color channels and every iteration after the
        # first augmented a BGR image instead of RGB.
        cv2.imwrite(aug_path, cv2.cvtColor(aug_img, cv2.COLOR_RGB2BGR))

        with open(new_txt_path, "w+") as out:
            for box_i, bbox_str in enumerate(new_bboxes):
                out.write(f"{field_ids[box_i]} {bbox_str}\n")
        i += 1
def parse_label(label_path):
    """Parse a Darknet format label, returning bounding boxes and class IDs.

    Each non-empty line is `class_id x_center y_center width height`.
    Returns (boxes, field_ids) where boxes is a list of np.float64 lists
    and field_ids is the parallel list of integer class IDs.
    """
    boxes = list()
    field_ids = list()
    # BUGFIX: the original used open(...).read().split("\n"), which leaked
    # the file handle and crashed with ValueError on the empty trailing
    # element produced by a final newline (int("")). Blank lines are now
    # skipped and the file is closed via the context manager.
    with open(label_path, "r") as label_file:
        for line in label_file:
            parts = line.split()
            if not parts:
                continue
            field_ids.append(int(parts[0]))
            boxes.append([np.float64(info) for info in parts[1:]])
    return boxes, field_ids