Labels: enhancement (New feature or request)
✨ Feature Request: visualize_clusters / _plot_cluster_scatter Method
Summary
Add a new utility method called visualize_clusters (alias: _plot_cluster_scatter) to the unsupervised module of the DataScienceUtils package. The function should help users visually evaluate the output of unsupervised clustering models by plotting the clustered data on a 2D scatter plot. It should optionally overlay cluster centroids when they are provided.
Motivation
Unsupervised learning algorithms such as k-means or DBSCAN group data into clusters, and visualizing the result is often a crucial step in model evaluation, especially when working with 2D data or data reduced to two dimensions with PCA, t-SNE, or UMAP. Providing this feature will:
- Help users better interpret clustering results.
- Reduce boilerplate code in notebooks.
- Improve the overall UX of the package by integrating visualization tools into the modeling workflow.
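For context, this is roughly the notebook boilerplate the feature would replace when the data has more than two features. It is a minimal sketch that assumes scikit-learn's `KMeans`, `PCA`, and `make_blobs` on synthetic data; none of these are part of the proposal. The important detail is that the centroids are projected with the same fitted PCA as the points, so both land in one 2D coordinate frame.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.decomposition import PCA

# Synthetic 4-D data with three groups (illustrative data, not from the issue)
X, _ = make_blobs(n_samples=300, n_features=4, centers=3, random_state=0)

# Cluster in the original 4-D feature space
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
labels = kmeans.labels_

# Fit PCA once and reuse it, so points and centroids share the same 2-D frame
pca = PCA(n_components=2).fit(X)
X_2d = pca.transform(X)
centers_2d = pca.transform(kmeans.cluster_centers_)

fig, ax = plt.subplots(figsize=(6, 5))
for label in sorted(set(labels)):
    mask = labels == label
    ax.scatter(X_2d[mask, 0], X_2d[mask, 1], alpha=0.6, s=50,
               label=f"Cluster {label}")

# Overlay the projected centroids
ax.scatter(centers_2d[:, 0], centers_2d[:, 1], c="red", marker="x",
           s=200, linewidths=3, label="Centroids")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.legend()
```

Fitting a separate PCA on the centroids alone would place them in an unrelated coordinate system, which is why the helper below calls `pca.transform` on the centers after `fit_transform` on the data.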
API Proposal
```python
import numpy as np
from sklearn.decomposition import PCA


def _plot_cluster_scatter(ax, X, cluster_labels, cluster_centers, colors, unique_labels):
    """Create a scatter plot of clusters (helper function).

    Plots the two features directly when X is 2-D; otherwise projects X (and the
    centroids, if given) onto the first two principal components. Returns ``ax``.
    """
    X = np.asarray(X)
    cluster_labels = np.asarray(cluster_labels)  # needed for boolean masking below

    if X.shape[1] == 2:
        # Already two-dimensional: plot the features as-is
        X_plot = X
        centers_plot = cluster_centers
    else:
        # Fit PCA on the data and reuse it for the centroids so both share one 2-D frame
        pca = PCA(n_components=2)
        X_plot = pca.fit_transform(X)
        centers_plot = pca.transform(cluster_centers) if cluster_centers is not None else None
        ax.text(0.02, 0.98, f'PCA Explained Variance: {pca.explained_variance_ratio_.sum():.1%}',
                transform=ax.transAxes, verticalalignment='top', fontsize=9,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    # Plot data points, one color per cluster
    for i, cluster_label in enumerate(unique_labels):
        cluster_data = X_plot[cluster_labels == cluster_label]
        ax.scatter(cluster_data[:, 0], cluster_data[:, 1],
                   c=[colors[i]], label=f'Cluster {cluster_label}',
                   alpha=0.6, s=50)

    # Overlay cluster centers if available
    if centers_plot is not None:
        centers_plot = np.asarray(centers_plot)
        ax.scatter(centers_plot[:, 0], centers_plot[:, 1],
                   c='red', marker='x', s=200, linewidths=3,
                   label='Centroids')

    ax.set_xlabel('Feature 1' if X.shape[1] == 2 else 'PC1', fontsize=12)
    ax.set_ylabel('Feature 2' if X.shape[1] == 2 else 'PC2', fontsize=12)
    ax.set_title('Cluster Visualization', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    return ax
```
To Do
- Create a `visualize_clusters` method in `unsupervised.py`.
- Implement the core functionality with `matplotlib.pyplot.scatter`, using a separate `scatter` call for the optional centroids.
- Ensure color differentiation between clusters using a color palette.
- Handle missing or incorrect column inputs gracefully by raising `ValueError`.
- Support an optional centroid overlay, with centroids matched to cluster labels.
- Return the `matplotlib.axes.Axes` object.
- Write unit tests in `tests/test_unsupervised.py` to validate:
  * Basic cluster plotting.
  * Handling of missing columns.
  * Correct plotting of centroids.
- Update the documentation and examples.
- Optionally, add support for an interactive mode (e.g., via Plotly) in the future.
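A minimal sketch of how the public `visualize_clusters` wrapper could cover the To Do items (column validation with `ValueError`, a color palette, an optional centroid overlay, returning the Axes). The DataFrame-based signature (`x_col`, `y_col`, `label_col`, `centroids`, `ax`) is a hypothetical choice made for illustration, not a settled API; it assumes pandas input so that "missing columns" has a concrete meaning.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def visualize_clusters(df, x_col, y_col, label_col, centroids=None, ax=None):
    """Scatter-plot clustered rows of ``df`` and return the Axes.

    Hypothetical signature for illustration only. Raises ValueError when a
    requested column is missing or when ``centroids`` is not (n_clusters, 2).
    """
    missing = [c for c in (x_col, y_col, label_col) if c not in df.columns]
    if missing:
        raise ValueError(f"Columns not found in DataFrame: {missing}")

    if ax is None:
        _, ax = plt.subplots()

    cmap = plt.get_cmap("tab10")  # qualitative palette; colors repeat after 10 clusters
    unique_labels = np.unique(df[label_col])
    for i, label in enumerate(unique_labels):
        subset = df[df[label_col] == label]
        ax.scatter(subset[x_col], subset[y_col], color=cmap(i % cmap.N),
                   alpha=0.6, s=50, label=f"Cluster {label}")

    if centroids is not None:
        centroids = np.asarray(centroids, dtype=float)
        # Expect one 2-D centroid per cluster, ordered like np.unique(labels)
        if centroids.shape != (len(unique_labels), 2):
            raise ValueError(
                f"Expected centroids of shape ({len(unique_labels)}, 2), "
                f"got {centroids.shape}"
            )
        ax.scatter(centroids[:, 0], centroids[:, 1], c="black", marker="x",
                   s=200, linewidths=3, label="Centroids")

    ax.set_xlabel(x_col)
    ax.set_ylabel(y_col)
    ax.set_title("Cluster Visualization")
    ax.legend()
    ax.grid(True, alpha=0.3)
    return ax
```

In this sketch the centroids are drawn in black rather than red because `tab10` already contains a red, which could make a centroid marker blend into one of the clusters. The three test cases listed above map directly onto a plain call, a call with a bad column name inside `pytest.raises(ValueError)`, and a call with a `centroids` array.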