Add plot_pca_explained_variance function to preprocess module #85

@idanmoradarthas

Description

Enhance the preprocess module by adding a new plotting function to visualize the cumulative explained variance ratio from PCA. This will help users quickly determine how many principal components are required to capture a desired proportion of variance.

The new function should:

  • Accept a numerical pandas DataFrame X (samples as rows, features as columns).
  • Include a boolean flag use_scaling (default: True) to optionally scale the data before PCA.
  • Accept a custom scaler object (default: StandardScaler()).
  • Accept an optional ax (matplotlib Axes); create a new figure if None.
  • Accept additional *args and **kwargs to pass directly to sklearn.decomposition.PCA.
  • Plot horizontal reference lines at 70% and 80% explained variance.
  • Accept a legend location parameter (default: loc="lower right").
  • Return the matplotlib Axes object, consistent with other plotting functions in preprocess.py.
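As a rough illustration of what the plot is meant to answer, the sketch below computes the cumulative explained variance ratio and reads off the smallest number of components that reaches the 70% and 80% reference levels. The helper name components_for_threshold and the synthetic data are assumptions made for this example only; they are not part of the proposed API.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def components_for_threshold(cumulative: np.ndarray, threshold: float) -> int:
    """Smallest number of components whose cumulative ratio reaches threshold.

    Hypothetical helper for illustration; not part of the proposed function.
    """
    # searchsorted with side="left" gives the first index where
    # cumulative[index] >= threshold (cumulative is non-decreasing).
    index = int(np.searchsorted(cumulative, threshold, side="left"))
    # Clip in case rounding leaves the final cumulative value just below threshold.
    return min(index + 1, len(cumulative))


# Synthetic data (assumption): 5 features driven by 2 latent factors plus noise.
rng = np.random.default_rng(0)
latent = rng.normal(size=(200, 2))
mixing = rng.normal(size=(2, 5))
X = latent @ mixing + 0.05 * rng.normal(size=(200, 5))

# Same steps as the proposed function: optional scaling, then a full PCA fit.
X_scaled = StandardScaler().fit_transform(X)
cumulative = np.cumsum(PCA().fit(X_scaled).explained_variance_ratio_)

for threshold in (0.70, 0.80):
    n_components = components_for_threshold(cumulative, threshold)
    print(f"{threshold:.0%} of variance: {n_components} component(s)")
```

On the plot, these counts correspond to the first marker that sits on or above each horizontal reference line.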

Proposed implementation

Add the following code to preprocess.py:

from typing import Optional
import pandas as pd
import numpy as np
from matplotlib import axes, pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.base import TransformerMixin


def plot_pca_explained_variance(
    X: pd.DataFrame,
    use_scaling: bool = True,
    scaler: TransformerMixin = StandardScaler(),
    ax: Optional[axes.Axes] = None,
    *args,
    loc: str = "lower right",
    **kwargs,
) -> axes.Axes:
    """Plot the cumulative explained variance ratio of PCA components.

    This visualization helps determine how many principal components are needed
    to capture a desired proportion of the total variance in the data.
    Horizontal reference lines are drawn at 70% and 80% variance.

    Parameters
    ----------
    X : pd.DataFrame
        Input data with numerical features (rows = samples, columns = features).
    use_scaling : bool, default=True
        If True, scale the data using the provided scaler before fitting PCA.
    scaler : TransformerMixin, default=StandardScaler()
        Scaler instance to use when use_scaling is True.
    ax : Optional[axes.Axes], default=None
        Matplotlib Axes to draw the plot on. If None, a new figure and Axes are created.
    loc : str, default="lower right"
        Legend location, passed to ax.legend. Keyword-only.
    *args, **kwargs
        Additional arguments passed directly to sklearn.decomposition.PCA.

    Returns
    -------
    axes.Axes
        The Axes object containing the plot.

    Raises
    ------
    ValueError
        If any column in X is non-numeric.
    """
    # Validate first so that invalid input does not leave an empty figure behind.
    if not np.all(X.dtypes.apply(pd.api.types.is_numeric_dtype)):
        raise ValueError("All columns in X must be numeric.")

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))

    X_array = X.to_numpy()

    if use_scaling:
        X_array = scaler.fit_transform(X_array)

    # Forward both positional and keyword arguments to PCA, as documented.
    pca = PCA(*args, **kwargs)
    pca.fit(X_array)

    explained_variance_ratio = pca.explained_variance_ratio_
    cumulative_variance = np.cumsum(explained_variance_ratio)

    ax.plot(
        range(1, len(cumulative_variance) + 1),
        cumulative_variance,
        marker="o",
        linestyle="-",
        color="b",
        label="Cumulative explained variance",
    )

    # Reference lines for common variance thresholds
    ax.axhline(0.70, color="gray", linestyle="--", linewidth=1, label="70% variance")
    ax.axhline(0.80, color="gray", linestyle="--", linewidth=1, label="80% variance")

    ax.set_xlabel("Number of Principal Components")
    ax.set_ylabel("Cumulative Explained Variance Ratio")
    ax.set_title("PCA - Cumulative Explained Variance")
    ax.grid(True)
    ax.legend(loc=loc)

    return ax

Additional tasks after implementation

  • Add an example usage to the ReadTheDocs documentation for the preprocess module.
  • Write unit tests (e.g., checking output with/without scaling and verifying Axes return).
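One possible shape for those unit tests is sketched below. It is written as a reusable check that takes the plotting function as an argument, so it makes no assumption about the final import path. The names make_frame, cumulative_curve and check_plot_contract are hypothetical, and the sketch assumes the cumulative curve is the first line drawn on the Axes, as in the proposed implementation.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the checks run without a display

import numpy as np
import pandas as pd
from matplotlib import axes, pyplot as plt


def make_frame(seed: int = 0) -> pd.DataFrame:
    """Three independent numeric features on very different scales."""
    rng = np.random.default_rng(seed)
    return pd.DataFrame(
        {
            "small": rng.normal(0, 1, 50),
            "mid": rng.normal(0, 10, 50),
            "large": rng.normal(0, 1000, 50),
        }
    )


def cumulative_curve(ax: axes.Axes) -> np.ndarray:
    # Assumption: the cumulative curve is the first line drawn on the Axes.
    return np.asarray(ax.get_lines()[0].get_ydata(), dtype=float)


def check_plot_contract(plot_fn) -> None:
    """Check any function that follows the proposed signature."""
    X = make_frame()

    # A supplied Axes is reused and returned; otherwise a new Axes is created.
    _, ax = plt.subplots()
    assert plot_fn(X, ax=ax) is ax
    assert isinstance(plot_fn(X), axes.Axes)

    curves = {}
    for use_scaling in (True, False):
        y = cumulative_curve(plot_fn(X, use_scaling=use_scaling))
        assert len(y) == X.shape[1]          # one point per component
        assert np.all(np.diff(y) >= -1e-12)  # cumulative, so non-decreasing
        assert np.isclose(y[-1], 1.0)        # all components explain everything
        curves[use_scaling] = y

    # Without scaling, the large-scale feature dominates the first component.
    assert curves[False][0] > curves[True][0]

    # Non-numeric columns are rejected.
    try:
        plot_fn(X.assign(label="a"))
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError for non-numeric input")
    finally:
        plt.close("all")
```

Inside the test suite this would be called as check_plot_contract(plot_pca_explained_variance), or split into separate test functions if finer-grained failure reports are preferred.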

This addition will make dimensionality analysis more accessible and consistent with the existing visualization utilities in the package.

Here is an example output: [image]

Labels: enhancement (new feature or request)
