Skip to content

Commit 04e8b28

Browse files
committed
Add error handling and input validation to calculate_kmeans function
1 parent 09e2a79 commit 04e8b28

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

tasnif/calculations.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,17 @@ def calculate_kmeans(pca_embeddings, num_classes):
4141
labels and centroids.
4242
"""
4343
print("KMeans processing...")
44-
centroid, labels = kmeans2(data=pca_embeddings, k=num_classes, minit="points")
45-
counts = np.bincount(labels)
46-
print("Kmeans done!")
47-
return centroid, labels, counts
44+
if not isinstance(pca_embeddings, np.ndarray):
45+
raise ValueError("pca_embeddings must be a numpy array")
46+
47+
if num_classes > len(pca_embeddings):
48+
raise ValueError(
49+
"num_classes must be less than or equal to the number of samples in pca_embeddings"
50+
)
51+
52+
try:
53+
centroid, labels = kmeans2(data=pca_embeddings, k=num_classes, minit="points")
54+
counts = np.bincount(labels)
55+
return centroid, labels, counts
56+
except Exception as e:
57+
raise RuntimeError(f"An error occurred during KMeans processing: {e}")

0 commit comments

Comments
 (0)