Skip to content

Commit

Permalink
explainability ease added. (#113)
Browse files Browse the repository at this point in the history
* feat: core metrics

* position parity added

* Delete holisticai/explainability/metrics/core/contrast_metrics.py

* position parity added with pre-commit

* deleted contrast metrics

* returns added to the description

* multiclass testing added

* multiclass testing added with pre-commit.

* vscode deleted

* position parity changed to get fi dataframes, not indexes

* bug fixed

* rank alignment metric added.

* changed

* fixed.

* explainability ease added.

* fixed.

* fixed.

---------

Co-authored-by: crismunoz <cristian.munoz@holisticai.com>
  • Loading branch information
aminatkhamokova and crismunoz authored Apr 17, 2024
1 parent e1c4f26 commit 743cadd
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
33 changes: 33 additions & 0 deletions holisticai/explainability/metrics/core/all_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
importance_order_constrast,
importance_range_constrast,
)
from holisticai.explainability.metrics.global_importance._explainability_level import (
compute_explainability_ease_score,
)


def position_parity(
Expand Down Expand Up @@ -63,3 +66,33 @@ def rank_alignment(
for i in conditional_feature_importance
]
)


def explainability_ease(partial_dependence_list: list[dict]):
    """
    Compute the explainability ease metric from partial dependence curves.

    Parameters
    ----------
    partial_dependence_list: list[dict]
        a list of dictionaries containing partial dependencies for each feature.
        For multiclass classification, partial dependencies are computed for each class separately.
        For binary classification, partial dependence is calculated only for the positive class,
        resulting in a single dictionary in the list. Similarly, for regression, there's only one dictionary.

    Returns
    -------
    float
        explainability ease value, average explainability ease value for multiclass setting
    """
    # compute_explainability_ease_score returns a tuple; its first element is the
    # scalar score. Compute one score per entry so the call is written only once.
    scores = [
        compute_explainability_ease_score(partial_dependence=partial_dependence)[0]
        for partial_dependence in partial_dependence_list
    ]
    # Single entry (binary classification / regression): return the score as-is,
    # preserving its original type. Multiclass: average across per-class scores.
    if len(scores) == 1:
        return scores[0]
    return np.mean(scores)
46 changes: 46 additions & 0 deletions tests/explainability/test_all_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@

from holisticai.datasets import load_dataset
from holisticai.explainability.metrics.core.all_metrics import (
explainability_ease,
position_parity,
rank_alignment,
)
from holisticai.explainability.metrics.global_importance._explainability_level import (
compute_partial_dependence,
)
from holisticai.explainability.metrics.utils import get_index_groups


Expand Down Expand Up @@ -197,3 +201,45 @@ def test_multiclass_classification_rank_alignment():
for _, index in index_groups.items()
]
assert rank_alignment(feat_importance, conditional_feature_importance) is not None


def test_binary_classification_explainability_ease():
    """Explainability ease yields a value for a binary classification model."""
    X_train, X_test, y_train, _, _ = binary_classification_process_dataset()
    classifier = train_model_classification(X_train, y_train)
    predictions = classifier.predict(X_test)
    importance = get_feat_importance(
        x=X_test, y=predictions, model=classifier, samples_len=len(X_test)
    )
    # Binary setting: partial dependence is computed for the positive class only.
    positive_class_pd = compute_partial_dependence(
        model=classifier, feature_importance=importance, x=X_test, target=1
    )
    result = explainability_ease(partial_dependence_list=[positive_class_pd])
    assert result is not None


def test_regression_explainability_ease():
    """Explainability ease yields a value for a regression model."""
    X_train, X_test, y_train, _, _ = regression_process_dataset()
    regressor = train_model_regression(X_train, y_train)
    predictions = regressor.predict(X_test)
    importance = get_feat_importance(
        x=X_test, y=predictions, model=regressor, samples_len=len(X_test)
    )
    # Regression has no class targets, hence target=None.
    regression_pd = compute_partial_dependence(
        model=regressor, feature_importance=importance, x=X_test, target=None
    )
    result = explainability_ease(partial_dependence_list=[regression_pd])
    assert result is not None


def test_multiclass_classification_explainability_ease():
    """Explainability ease averages per-class scores for a multiclass model."""
    X_train, X_test, y_train, _ = multiclass_classification_process_dataset()
    classifier = train_model_classification(X_train, y_train)
    predictions = classifier.predict(X_test)
    importance = get_feat_importance(
        x=X_test, y=predictions, model=classifier, samples_len=len(X_test)
    )
    # Build one partial-dependence dictionary per predicted class label.
    per_class_pd = []
    for class_label in np.unique(predictions):
        per_class_pd.append(
            compute_partial_dependence(
                model=classifier,
                feature_importance=importance,
                x=X_test,
                target=class_label,
            )
        )
    result = explainability_ease(partial_dependence_list=per_class_pd)
    assert result is not None

0 comments on commit 743cadd

Please sign in to comment.