Skip to content

Commit e083ce3

Browse files
committed
PersonDatasetAssembler: fix raw extensions being case-sensetive
1 parent cd1e275 commit e083ce3

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

PersonDatasetAssembler/PersonDatasetAssembler.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import cv2
2626
import numpy
2727
from tqdm import tqdm
28-
from wand.exceptions import BlobError
28+
from wand.exceptions import BlobError, CoderError
2929
from wand.image import Image
3030

3131
image_ext_ocv = [".bmp", ".jpeg", ".jpg", ".png"]
@@ -41,7 +41,7 @@ def find_image_files(path: str) -> list[str]:
4141
for root, dirs, files in os.walk(path):
4242
for filename in files:
4343
name, extension = os.path.splitext(filename)
44-
if extension.lower() in image_ext_ocv or extension in image_ext_wand:
44+
if extension.lower() in image_ext_ocv or extension.lower() in image_ext_wand:
4545
paths.append(os.path.join(root, filename))
4646
return paths
4747

@@ -58,10 +58,20 @@ def image_loader(paths: list[str]) -> Iterator[numpy.ndarray]:
5858
yield image
5959
elif extension in image_ext_wand:
6060
try:
61-
image = Image(filename=path)
61+
wandImage = Image(filename=path)
62+
wandImage.auto_orient()
63+
image = numpy.array(wandImage)
64+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
65+
yield image
6266
except BlobError as e:
6367
print(f"Warning: could not load {path}, {e}")
6468
continue
69+
except CoderError as e:
70+
print(f"Warning: failure in wand while loading {path}, {e}")
71+
72+
else:
73+
print(f"Warning: could not load {path}, {e}")
74+
continue
6575

6676

6777
def extract_video_images(video: cv2.VideoCapture, interval: int = 0):
@@ -132,7 +142,7 @@ def process_referance(detector: cv2.FaceDetectorYN, recognizer: cv2.FaceRecogniz
132142

133143
recognizer = cv2.FaceRecognizerSF.create(model=args.match_model, config="", backend_id=cv2.dnn.DNN_BACKEND_DEFAULT , target_id=cv2.dnn.DNN_TARGET_CPU)
134144
detector = cv2.FaceDetectorYN.create(model=args.detect_model, config="", input_size=[320, 320],
135-
score_threshold=0.6, nms_threshold=0.3, top_k=5000, backend_id=cv2.dnn.DNN_BACKEND_DEFAULT, target_id=cv2.dnn.DNN_TARGET_CPU)
145+
score_threshold=0.4, nms_threshold=0.2, top_k=5000, backend_id=cv2.dnn.DNN_BACKEND_DEFAULT, target_id=cv2.dnn.DNN_TARGET_CPU)
136146

137147
referance_features = process_referance(detector, recognizer, args.referance)
138148
if len(referance_features) < 1:
@@ -166,7 +176,7 @@ def process_referance(detector: cv2.FaceDetectorYN, recognizer: cv2.FaceRecogniz
166176
resized = image
167177
score, match = contains_face_match(detector, recognizer, resized, referance_features, args.threshold)
168178
if match and not args.invert or not match and args.invert:
169-
filename = f"{counter:04}.png"
179+
filename = f"{counter:04}.jpg"
170180
cv2.imwrite(os.path.join(args.out, filename), image)
171181
counter += 1
172182
progress.set_description(f"{score:1.2f}")

0 commit comments

Comments
 (0)