Feature: visualize_clusters Method #66

@idanmoradarthas

✨ Feature Request: visualize_clusters / _plot_cluster_scatter Method

Summary

Add a new utility method, visualize_clusters (alias: _plot_cluster_scatter), to the unsupervised module of the DataScienceUtils package. It should help users visually evaluate the output of unsupervised models by drawing the clustered data as a 2D scatter plot, and it should optionally overlay cluster centroids when they are provided.


Motivation

Unsupervised algorithms such as k-means or DBSCAN group data into clusters, and visualizing those clusters is often a crucial step in model evaluation, especially when working with 2D data or data reduced with PCA, t-SNE, or UMAP. Providing this feature will:

  • Help users better interpret clustering results.
  • Reduce boilerplate code in notebooks.
  • Improve the overall UX of the package by integrating visualization tools into the modeling workflow.

API Proposal

import numpy as np


def _plot_cluster_scatter(ax, X, cluster_labels, cluster_centers, colors, unique_labels):
    """Draw a 2D scatter plot of clustered data on ``ax`` (helper function).

    Data with more than two features is projected onto its first two
    principal components; centroids, if given, get the same projection.
    """
    # Boolean masking below needs ndarrays, so accept lists and DataFrames too
    X = np.asarray(X)
    cluster_labels = np.asarray(cluster_labels)
    is_2d = X.shape[1] == 2

    if is_2d:
        # Already two features: plot the data as-is
        X_plot = X
        centers_plot = None if cluster_centers is None else np.asarray(cluster_centers)
    else:
        # Apply PCA for visualization (imported lazily; only needed here)
        from sklearn.decomposition import PCA

        pca = PCA(n_components=2)
        X_plot = pca.fit_transform(X)
        centers_plot = None if cluster_centers is None else pca.transform(cluster_centers)

        # Show how much variance the 2D projection retains
        ax.text(0.02, 0.98, f'PCA Explained Variance: {pca.explained_variance_ratio_.sum():.1%}',
                transform=ax.transAxes, verticalalignment='top', fontsize=9,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    # Plot data points, one scatter call (and legend entry) per cluster
    for i, cluster_label in enumerate(unique_labels):
        cluster_data = X_plot[cluster_labels == cluster_label]
        ax.scatter(cluster_data[:, 0], cluster_data[:, 1],
                   c=[colors[i]], label=f'Cluster {cluster_label}',
                   alpha=0.6, s=50)

    # Plot cluster centers if available
    if centers_plot is not None:
        ax.scatter(centers_plot[:, 0], centers_plot[:, 1],
                   c='red', marker='x', s=200, linewidths=3,
                   label='Centroids')

    ax.set_xlabel('Feature 1' if is_2d else 'PC1', fontsize=12)
    ax.set_ylabel('Feature 2' if is_2d else 'PC2', fontsize=12)
    ax.set_title('Cluster Visualization', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)

To Do

  • Create visualize_clusters method in unsupervised.py.
  • Implement core functionality using matplotlib's Axes.scatter for the data points and, optionally, a second Axes.scatter call for centroids.
  • Ensure color differentiation between clusters using a color palette.
  • Handle missing or incorrect column inputs gracefully with ValueError.
  • Support optional centroids overlay with matching cluster labels.
  • Return the matplotlib.axes.Axes object.
  • Write unit tests under tests/test_unsupervised.py to validate:
    * Basic cluster plotting.
    * Handling of missing columns.
    * Correct plotting of centroids.
  • Update the documentation and examples.
  • Optionally add support for interactive mode (e.g., via Plotly) in the future.
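
To make the checklist concrete, here is a rough sketch of what the public visualize_clusters entry point could look like. The signature (array-like X, optional cluster_centers and ax) is an assumption modeled on the helper in the API proposal; the issue does not fix it, and a DataFrame-with-column-names variant would need different validation. The sketch inlines its own plotting so that it runs on its own, whereas a real implementation would presumably delegate to _plot_cluster_scatter.

```python
import matplotlib.pyplot as plt
import numpy as np


def visualize_clusters(X, cluster_labels, cluster_centers=None, ax=None):
    """Plot clustered data in 2D and return the Axes (assumed signature)."""
    X = np.asarray(X, dtype=float)
    cluster_labels = np.asarray(cluster_labels)

    # Fail early with ValueError on malformed input
    if X.ndim != 2 or X.shape[1] < 2:
        raise ValueError("X must be 2-dimensional with at least two feature columns")
    if cluster_labels.shape != (X.shape[0],):
        raise ValueError("cluster_labels must contain exactly one label per row of X")
    if cluster_centers is not None:
        cluster_centers = np.asarray(cluster_centers, dtype=float)
        if cluster_centers.ndim != 2 or cluster_centers.shape[1] != X.shape[1]:
            raise ValueError("cluster_centers must have the same number of columns as X")

    # Project to 2D only when needed, as in the proposed helper
    if X.shape[1] > 2:
        from sklearn.decomposition import PCA

        pca = PCA(n_components=2)
        X = pca.fit_transform(X)
        if cluster_centers is not None:
            cluster_centers = pca.transform(cluster_centers)

    if ax is None:
        _, ax = plt.subplots()

    # One color per cluster from a qualitative palette (assumed: tab10)
    cmap = plt.get_cmap("tab10")
    for i, label in enumerate(np.unique(cluster_labels)):
        points = X[cluster_labels == label]
        ax.scatter(points[:, 0], points[:, 1], color=cmap(i % cmap.N),
                   label=f"Cluster {label}", alpha=0.6, s=50)

    # Optional centroid overlay
    if cluster_centers is not None:
        ax.scatter(cluster_centers[:, 0], cluster_centers[:, 1],
                   c="red", marker="x", s=200, linewidths=3, label="Centroids")

    ax.legend()
    return ax
```

Returning the Axes lets callers adjust titles or save the figure afterwards, and it gives unit tests something to inspect, for example the number of scatter collections or the legend entries.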

Labels

enhancement: New feature or request
