Skip to content

Commit

Permalink
test: implementing unit tests for correlation plot (#77)
Browse files Browse the repository at this point in the history
* test: implementing unit tests for correlation plot

* feat: adding validation to correlation plots

* feat: adding unit test for no feature on dataframe

---------

Co-authored-by: Franklin Fernandez <franklincf@tecgraf.puc-rio.br>
Co-authored-by: cristian.munoz <cristian.munoz@holisticai.com>
  • Loading branch information
3 people authored Sep 20, 2023
1 parent 11eba44 commit 4bef184
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 7 deletions.
16 changes: 9 additions & 7 deletions holisticai/bias/plots/_bias_exploratory_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

# utils
from ...utils import get_colors
from ...utils._validation import _regression_checks
from ...utils._validation import (
_check_columns,
_check_numerical_dataframe,
_regression_checks,
)
from ..metrics import confusion_matrix


Expand Down Expand Up @@ -266,6 +270,8 @@ def correlation_matrix_plot(
"""Plot the correlation matrix of a given dataframe with respect to
a given target and a certain number of features.
Obs. The dataframe must contain only numerical features.
Parameters
----------
df : (DataFrame)
Expand All @@ -288,12 +294,8 @@ def correlation_matrix_plot(
matplotlib ax
"""
"""Prints the correlation matrix """
try:
df = df.astype(int)
except:
raise TypeError(
"Dataframe 'df' cannot be converted to int. All the values must be numerical."
)
df = _check_numerical_dataframe(df)
_check_columns(df, target_feature)

sns.set(font_scale=1.25)
if ax is None:
Expand Down
47 changes: 47 additions & 0 deletions holisticai/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,50 @@ def _check_valid_y_proba(y_proba: np.ndarray):
assert (
correct_proba_values
), f"""probability values must be in the interval [0,1], found: {y_proba}"""


def _check_numerical_dataframe(df: pd.DataFrame):
"""
Check numerical DataFrame
Description
----------
This function checks if a dataframe is numerical
or can be converted to numerical.
Parameters
----------
df : pandas DataFrame
input
Returns
-------
ValueError or converted DataFrame
"""
try:
return df.astype(float)
except ValueError:
raise ValueError("DataFrame cannot be converted to numerical values")


def _check_columns(df: pd.DataFrame, column: str):
"""
Check columns
Description
----------
This function checks if a column exists in a dataframe.
Parameters
----------
df : pandas DataFrame
input
column : str
column name
Returns
-------
ValueError or None
"""
if column not in df.columns:
raise ValueError(f"Column '{column}' does not exist in DataFrame")
57 changes: 57 additions & 0 deletions tests/bias/plots/test_all_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest

from holisticai.bias.metrics import classification_bias_metrics, regression_bias_metrics

Expand All @@ -15,6 +16,7 @@
abroca_plot,
accuracy_bar_plot,
bias_metrics_report,
correlation_matrix_plot,
disparate_impact_curve,
disparate_impact_plot,
distribution_plot,
Expand Down Expand Up @@ -98,6 +100,61 @@ def test_histogram_plot(monkeypatch):
assert True


@pytest.mark.xfail(raises=ValueError)
def test_correlation_plot_non_numerical_data(monkeypatch):
"""test_correlation_plot: This test should fail because the data is not numerical"""
monkeypatch.setattr(plt, "show", lambda: None)
_, ax = plt.subplots()
correlation_matrix_plot(
df,
target_feature="class",
n_features=10,
cmap="YlGnBu",
ax=ax,
size=None,
title=None,
)


def test_correlation_plot_numerical_data(monkeypatch):
"""test_correlation_plot"""
monkeypatch.setattr(plt, "show", lambda: None)
_, ax = plt.subplots()
# ensure dataframes are numerical
df_ = df.copy()
df_clean = df_.iloc[
:, [i for i, n in enumerate(df_.isna().sum(axis=0).T.values) if n < 100]
]
df_clean.drop(
columns=["sex", "race", "education", "marital-status", "relationship"],
inplace=True,
)
df_clean["class"].replace({">50K": 1, "<=50K": 0}, inplace=True)
correlation_matrix_plot(
df_clean,
target_feature="class",
n_features=5,
cmap="YlGnBu",
ax=ax,
size=None,
title=None,
)
assert True


@pytest.mark.xfail(raises=ValueError)
def test_correlation_plot_numerical_data_no_feature(monkeypatch):
"""test_correlation_plot: This test should fail because the feature is not in the dataframe"""
from sklearn.datasets import load_diabetes

dataset = load_diabetes() # numerical dataset
X = dataset.data
feature_names = dataset.feature_names
X = pd.DataFrame(X, columns=feature_names)

correlation_matrix_plot(X, target_feature="ages", size=(12, 7))


def test_frequency_plot(monkeypatch):
"""test_frequency_plot"""
monkeypatch.setattr(plt, "show", lambda: None)
Expand Down

0 comments on commit 4bef184

Please sign in to comment.