Skip to content
This repository was archived by the owner on Oct 4, 2024. It is now read-only.

Commit e48af8a

Browse files
committed
Create augmentor.py
augment the dataset by random filling cracks
1 parent 42f8111 commit e48af8a

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

tinytools/augmentor.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import shutil
2+
import cv2
3+
import os
4+
import random
5+
6+
import numpy as np
7+
8+
class Augmentor:
9+
def __init__(self, path, aug_path):
10+
self.path = path
11+
self.aug_path = aug_path
12+
self.images_path = os.path.join(path, "images")
13+
self.masks_path = os.path.join(path, "masks")
14+
self.aug_images_path = os.path.join(aug_path, "images")
15+
self.aug_masks_path = os.path.join(aug_path, "masks")
16+
self.augmented_count = 0
17+
18+
def check_dir(self, directory):
19+
if os.path.exists(directory):
20+
shutil.rmtree(directory)
21+
if not os.path.exists(directory):
22+
os.makedirs(directory)
23+
24+
def count_cracks(self, image_path):
25+
# Read the image
26+
image = cv2.imread(image_path, 0) # Read as grayscale
27+
28+
# Convert the image to binary
29+
_, binary_image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY)
30+
31+
# Find contours in the binary image
32+
contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
33+
34+
# Count the number of cracks
35+
num_cracks = len(contours)
36+
37+
return num_cracks, contours
38+
39+
def calculate_channel_median(self, image_path):
40+
# Read the image
41+
image = cv2.imread(image_path)
42+
43+
# Split the image into RGB channels
44+
b, g, r = cv2.split(image)
45+
46+
# Calculate the median value for each channel
47+
median_r = np.median(r)
48+
median_g = np.median(g)
49+
median_b = np.median(b)
50+
51+
return median_r, median_g, median_b
52+
53+
def augment_images(self):
54+
self.check_dir(self.aug_path)
55+
self.check_dir(self.aug_images_path)
56+
self.check_dir(self.aug_masks_path)
57+
58+
for filename in os.listdir(self.images_path):
59+
# Check if there is a corresponding mask file
60+
mask_filename = os.path.join(self.masks_path, filename)
61+
if not os.path.isfile(mask_filename):
62+
continue
63+
64+
# Check if the image is all black
65+
image = cv2.imread(os.path.join(self.images_path, filename))
66+
mask = cv2.imread(mask_filename, 0)
67+
if cv2.countNonZero(mask) == 0:
68+
continue
69+
70+
# Calculate the number of cracks
71+
num_cracks, contours = self.count_cracks(mask_filename)
72+
73+
if num_cracks == 1:
74+
continue
75+
76+
# Calculate the median value for each channel
77+
median = self.calculate_channel_median(os.path.join(self.images_path, filename))
78+
79+
# Randomly select a subset of cracks to erase
80+
num_to_erase = int(num_cracks / 2)
81+
cracks_to_erase = random.sample(range(num_cracks), num_to_erase)
82+
83+
# Load the mask image
84+
mask = cv2.imread(mask_filename, 0)
85+
86+
# Erase the selected cracks in the mask
87+
for i in cracks_to_erase:
88+
background = np.zeros_like(image)
89+
background = cv2.cvtColor(background, cv2.COLOR_BGR2GRAY)
90+
cv2.drawContours(background, [contours[i]], -1, 255, thickness=cv2.FILLED)
91+
92+
# Dilate the cracks in the background image
93+
background = cv2.dilate(background, np.ones((5, 5), np.uint8), iterations=1)
94+
_, background = cv2.threshold(background, 1, 255, cv2.THRESH_BINARY)
95+
96+
# Fill the selected cracks in the image
97+
filled_image = image.copy()
98+
filled_image[background == 255] = median
99+
filled_image = cv2.inpaint(filled_image, background, 3, cv2.INPAINT_TELEA)
100+
101+
# Replace the original image with the filled image
102+
image = filled_image
103+
cv2.drawContours(mask, [contours[i]], -1, 0, thickness=cv2.FILLED)
104+
105+
# Save the modified image and mask in the augmented directory
106+
cv2.imwrite(os.path.join(self.aug_images_path, filename), image)
107+
cv2.imwrite(os.path.join(self.aug_masks_path, filename), mask)
108+
109+
self.augmented_count += 1 # Increment the count of augmented images
110+
111+
print("Total number of augmented images:", self.augmented_count)
112+
113+
def main():
114+
path = "input" # 替换为你的路径
115+
aug_path = "augmented"
116+
117+
augmentor = Augmentor(path, aug_path)
118+
augmentor.augment_images()
119+
120+
if __name__ == "__main__":
121+
main()

0 commit comments

Comments
 (0)