Skip to content

Commit

Permalink
Merge pull request #63 from salesforce/bias
Browse files Browse the repository at this point in the history
Bias and bug fix
  • Loading branch information
yangwenz authored Jan 3, 2023
2 parents b1ce148 + 355264c commit cc65dab
Show file tree
Hide file tree
Showing 16 changed files with 551 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ We will continue improving this library to make it more comprehensive in the fut
| Feature analysis | NA | Global || | | | |
| Feature selection | NA | Global || | | | |
| Prediction metrics | Black box | Global | |||||
| Bias metrics | Black box | Global | || | | |
| Partial dependence plots | Black box | Global | || | | |
| Accumulated local effects | Black box | Global | || | | |
| Sensitivity analysis | Black box | Global | || | | |
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Method Model Type Explanation Type EDA
Feature analysis NA Global ✓
Feature selection NA Global ✓
Prediction metrics Black box Global ✓ ✓ ✓ ✓
Bias metrics Black box Global ✓
PDP Black box Global ✓
ALE Black box Global ✓
Sensitivity analysis Black box Global ✓
Expand Down
11 changes: 11 additions & 0 deletions docs/omnixai.explainers.tabular.agnostic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ omnixai.explainers.tabular.agnostic package
lime
shap
pdp
ale
sensitivity
L2X.l2x
permutation
bias

omnixai.explainers.tabular.agnostic.lime module
-----------------------------------------------
Expand Down Expand Up @@ -76,3 +79,11 @@ omnixai.explainers.tabular.agnostic.shap_global module
:members:
:undoc-members:
:show-inheritance:

omnixai.explainers.tabular.agnostic.bias module
-----------------------------------------------

.. automodule:: omnixai.explainers.tabular.agnostic.bias
:members:
:undoc-members:
:show-inheritance:
3 changes: 0 additions & 3 deletions omnixai/explainers/nlp/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def __init__(
params=params,
)

def _convert_data(self, X):
return Text(X)

@staticmethod
def list_explainers():
"""
Expand Down
10 changes: 9 additions & 1 deletion omnixai/explainers/prediction/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
preprocess: Callable = None,
postprocess: Callable = None,
predict_function: Callable = None,
**kwargs
):
"""
:param mode: The task type, e.g., `classification` and `regression`.
Expand Down Expand Up @@ -110,10 +111,17 @@ def __init__(

self.mode = mode
self.y_test = test_targets.astype(int) if mode == "classification" else test_targets
self.y_prob = self.predict_function(test_data)
self.y_prob = self._predict(test_data, batch_size=kwargs.get("batch_size", 128))
if mode == "classification":
self.num_classes = self.y_prob.shape[1]

def _predict(self, x, batch_size=128):
    """
    Runs `self.predict_function` on `x` in mini-batches and stacks the outputs.

    :param x: The input samples. It must expose `shape[0]` and support slicing
        along the first axis.
    :param batch_size: The maximum number of samples passed to `predict_function`
        in a single call.
    :return: The predictions of all batches concatenated along the first axis,
        flattened to 1-D when `self.mode` is `regression`.
    :raises ValueError: If `batch_size` is not positive.
    """
    if batch_size <= 0:
        # A zero step breaks `range` and a negative one yields no batches at all,
        # so fail early with a message that names the actual problem.
        raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}.")

    n = x.shape[0]
    if n == 0:
        # There is no batch to iterate over and `np.concatenate` rejects an empty
        # list, so hand the empty input to the prediction function in one call.
        z = np.asarray(self.predict_function(x))
    else:
        z = np.concatenate(
            [self.predict_function(x[i: i + batch_size]) for i in range(0, n, batch_size)],
            axis=0,
        )
    return z.flatten() if self.mode == "regression" else z

def _roc(self, **kwargs) -> ROCExplanation:
"""
Computes the micro-average ROC curve, macro-average ROC curve and ROC curves of all the classes.
Expand Down
2 changes: 2 additions & 0 deletions omnixai/explainers/tabular/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .agnostic.L2X.l2x import L2XTabular
from .agnostic.permutation import PermutationImportance
from .agnostic.shap_global import GlobalShapTabular
from .agnostic.bias import BiasAnalyzer
from .counterfactual.mace.mace import MACEExplainer
from .counterfactual.ce import CounterfactualExplainer
from .counterfactual.knn import KNNCounterfactualExplainer
Expand All @@ -34,6 +35,7 @@
"L2XTabular",
"PermutationImportance",
"GlobalShapTabular",
"BiasAnalyzer",
"MACEExplainer",
"CounterfactualExplainer",
"KNNCounterfactualExplainer",
Expand Down
305 changes: 305 additions & 0 deletions omnixai/explainers/tabular/agnostic/bias.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion omnixai/explainers/tabular/agnostic/pdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, training_data: Tabular, predict_function, mode="classificatio
candidates = sorted(np.unique(self.data[:, column_index]))
else:
# Continuous-valued features
percentiles = np.linspace(1, 99, num=grid_resolution)
percentiles = np.linspace(0, 100, num=grid_resolution)
candidates = sorted(set(np.percentile(self.data[:, column_index], percentiles)))
self.candidates[column_index] = candidates

Expand Down
2 changes: 2 additions & 0 deletions omnixai/explainers/tabular/agnostic/shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def __init__(
self.link = kwargs.get("link", None)
if self.link is None:
self.link = "logit" if self.mode == "classification" else "identity"
else:
del kwargs["link"]

self.ignored_features = set(ignored_features) if ignored_features is not None else set()
if self.target_column is not None:
Expand Down
5 changes: 3 additions & 2 deletions omnixai/explanations/tabular/ale.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _plotly_figure(self, class_names=None, **kwargs):

explanations = self.get_explanations()
features = list(explanations.keys())
num_cols = 2
num_cols = min(2, len(features))
num_rows = int(np.ceil(len(features) / num_cols))
fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=features)
for i, feature in enumerate(features):
Expand Down Expand Up @@ -129,7 +129,8 @@ def _plotly_figure(self, class_names=None, **kwargs):
line=dict(color="#808080"),
legendgroup="Target"),
row=row + 1, col=col + 1)
fig.update_layout(height=260 * num_rows)
if num_rows > 1:
fig.update_layout(height=260 * num_rows)
return fig

def plotly_plot(self, class_names=None, **kwargs):
Expand Down
134 changes: 134 additions & 0 deletions omnixai/explanations/tabular/bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
The bias analysis results for tabular data.
"""
import numpy as np
from ..base import ExplanationBase, DashFigure
from collections import OrderedDict


class BiasExplanation(ExplanationBase):
    """
    The class for bias analysis results. The results are stored in a dict that maps
    each metric name to a dict of metric values keyed by label (classification) or
    target threshold (regression).
    """

    def __init__(self, mode):
        """
        :param mode: The task type, e.g., `classification` or `regression`.
        """
        super().__init__()
        self.mode = mode
        self.explanations = OrderedDict()

    def add(self, metric_name, metric_values):
        """
        Adds a new bias metric. An existing metric with the same name is overwritten.

        :param metric_name: The bias metric name.
        :param metric_values: The bias metric values, a dict keyed by label or threshold.
        """
        self.explanations[metric_name] = metric_values

    def get_explanations(self):
        """
        Gets the bias analysis results.

        :return: A dict containing the bias analysis results with the following format:
            `{metric_name: {"feature value or threshold": the metric value}, ...}`.
        """
        return self.explanations

    def _rearrange_metrics(self):
        """
        Regroups the stored metrics by label instead of by metric name.

        :return: A tuple `(metric_names, labels, label_metrics)` where
            `label_metrics[i][j]` is the value of `metric_names[j]` for `labels[i]`.
            All three lists are empty when no metric has been added.
        """
        metric_names = list(self.explanations.keys())
        if not metric_names:
            # Nothing has been added yet; avoid indexing into an empty list.
            return [], [], []
        # The labels are taken from the first metric; every metric is indexed with them.
        labels = sorted(self.explanations[metric_names[0]].keys())
        label_metrics = [
            [self.explanations[metric][label] for metric in metric_names]
            for label in labels
        ]
        return metric_names, labels, label_metrics

    def _subplot_title(self, label):
        """
        Builds the title of the figure or subplot that shows the metrics of `label`.

        :param label: The class label (classification) or target threshold (regression).
        :return: The title string.
        """
        if self.mode == "classification":
            return f"Label: {label}"
        return f"Target threshold: {label}"

    def plot(self, **kwargs):
        """
        Returns matplotlib figures showing the bias analysis results.

        :return: A list of matplotlib figures, one per label or threshold. The list is
            empty when no metric has been added.
        """
        import matplotlib.pyplot as plt

        figures = []
        metric_names, labels, label_metrics = self._rearrange_metrics()
        for i, label in enumerate(labels):
            fig, axes = plt.subplots(1, 1)
            # Ascending order of magnitude: the last bar drawn by `barh` is the topmost one.
            metric_scores = sorted(
                zip([f"{f} " for f in metric_names], label_metrics[i]),
                key=lambda x: abs(x[1]),
            )
            fnames = [f for f, _ in metric_scores]
            scores = [s for _, s in metric_scores]
            colors = ["green" if s > 0 else "red" for s in scores]
            positions = np.arange(len(scores)) + 0.5

            plt.sca(axes)
            plt.barh(positions, scores, align="center", color=colors)
            axes.yaxis.set_ticks_position("right")
            plt.yticks(positions, fnames, ha="right")
            plt.title(self._subplot_title(label))
            plt.grid()
            figures.append(fig)
        return figures

    def _plotly_figure(self, **kwargs):
        """
        Builds a plotly figure with one bar-chart subplot per label or threshold.

        :return: A plotly figure. It contains a single empty subplot when no metric
            has been added.
        """
        from plotly.subplots import make_subplots
        import plotly.graph_objects as go

        metric_names, labels, label_metrics = self._rearrange_metrics()
        # Keep at least a 1x1 grid so that empty results give an empty figure instead
        # of a division by zero.
        num_cols = max(1, min(2, len(labels)))
        num_rows = max(1, int(np.ceil(len(labels) / num_cols)))
        subplot_titles = [self._subplot_title(label) for label in labels]
        fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=subplot_titles)

        for i, label in enumerate(labels):
            row, col = divmod(i, num_cols)
            # Largest magnitude first so the most pronounced metrics appear on the left.
            metric_scores = sorted(
                zip(metric_names, label_metrics[i]),
                key=lambda x: abs(x[1]), reverse=True
            )
            fnames = [f for f, _ in metric_scores]
            scores = [s for _, s in metric_scores]
            colors = ["#008B8B" if s > 0 else "#DC143C" for s in scores]
            fig.add_trace(
                go.Bar(x=fnames, y=scores, marker_color=colors),
                row=row + 1, col=col + 1
            )
        if num_rows > 1:
            fig.update_layout(height=260 * num_rows)
        return fig

    def plotly_plot(self, **kwargs):
        """
        Returns a plotly dash figure showing the bias analysis results.

        :return: A plotly dash figure plotting bias analysis results.
        """
        return DashFigure(self._plotly_figure(**kwargs))

    def ipython_plot(self, **kwargs):
        """
        Shows the bias analysis results in IPython.
        """
        import plotly

        plotly.offline.iplot(self._plotly_figure(**kwargs))

    @classmethod
    def from_dict(cls, d):
        """
        Restores an explanation object from its dict representation.

        :param d: A dict with the keys `mode` and `explanations`.
        :return: An instance of `cls` holding the given results.
        """
        # Use `cls` so that subclasses are restored as their own type.
        exp = cls(mode=d["mode"])
        # Keep the same container type that `__init__` sets up.
        exp.explanations = OrderedDict(d["explanations"])
        return exp
5 changes: 3 additions & 2 deletions omnixai/explanations/tabular/pdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _plotly_figure(self, class_names=None, **kwargs):

explanations = self.get_explanations()
features = list(explanations.keys())
num_cols = 2
num_cols = min(2, len(features))
num_rows = int(np.ceil(len(features) / num_cols))
fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=features)
for i, feature in enumerate(features):
Expand Down Expand Up @@ -138,7 +138,8 @@ def _plotly_figure(self, class_names=None, **kwargs):
line=dict(color="#808080"),
legendgroup="Target"),
row=row + 1, col=col + 1)
fig.update_layout(height=260 * num_rows)
if num_rows > 1:
fig.update_layout(height=260 * num_rows)
return fig

def plotly_plot(self, class_names=None, **kwargs):
Expand Down
38 changes: 38 additions & 0 deletions omnixai/tests/explainers/bias/test_bias_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
import os
import unittest
from omnixai.utils.misc import set_random_seed
from omnixai.explainers.tabular.agnostic.bias import BiasAnalyzer
from omnixai.tests.explainers.tasks import TabularClassification


class TestClassificationBias(unittest.TestCase):
    """Smoke test running `BiasAnalyzer` in classification mode."""

    def test_classification_metric(self):
        set_random_seed()
        # The task helper is given the parent folder of this test directory.
        this_dir = os.path.dirname(os.path.abspath(__file__))
        base_folder = os.path.join(this_dir, "..")
        task = TabularClassification(base_folder).train_adult(num_training_samples=2000)

        def predict_function(z):
            # Apply the task's transform first, then ask the model for class probabilities.
            return task.model.predict_proba(task.transform.transform(z))

        # The held-out split of the task is passed in as the data to analyze.
        analyzer = BiasAnalyzer(
            mode="classification",
            predict_function=predict_function,
            training_data=task.test_data,
            training_targets=task.test_targets,
        )
        results = analyzer.explain(
            feature_column="Sex",
            feature_value_or_threshold=["Female", ["Male"]],
            label_value_or_threshold=1,
        )
        print(results.get_explanations())
        # Only checks that building the dash figure does not raise.
        results.plotly_plot()


# Run the tests of this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
37 changes: 37 additions & 0 deletions omnixai/tests/explainers/bias/test_bias_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
import unittest
import numpy as np
from omnixai.utils.misc import set_random_seed
from omnixai.explainers.tabular.agnostic.bias import BiasAnalyzer
from omnixai.tests.explainers.tasks import TabularRegression


class TestRegressionBias(unittest.TestCase):
    """Smoke test running `BiasAnalyzer` in regression mode."""

    # NOTE(review): the method name says "classification" although it runs in
    # regression mode; kept unchanged so existing test IDs stay valid.
    def test_classification_metric(self):
        set_random_seed()
        task = TabularRegression().train_boston()

        def predict_function(z):
            # Apply the task's transform first, then call the model's `predict`.
            return task.model.predict(task.transform.transform(z))

        # The held-out split of the task is passed in as the data to analyze.
        analyzer = BiasAnalyzer(
            mode="regression",
            predict_function=predict_function,
            training_data=task.test_data,
            training_targets=task.test_targets,
        )
        results = analyzer.explain(
            feature_column="LSTAT",
            feature_value_or_threshold=10,
            label_value_or_threshold=22,
        )
        print(results.get_explanations())
        # Only checks that building the dash figure does not raise.
        results.plotly_plot()


# Run the tests of this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def test_1(self):
training_data=task.train_data,
predict_function=predict_function,
ignored_features=None,
nsamples=150
nsamples=150,
link="logit"
)

i = 1653
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name="omnixai",
version="1.2.3",
version="1.2.4",
author="Wenzhuo Yang, Hung Le, Tanmay Shivprasad Laud, Silvio Savarese, Steven C.H. Hoi",
description="OmniXAI: An Explainable AI Toolbox",
long_description=open("README.md", "r", encoding="utf-8").read(),
Expand All @@ -43,7 +43,7 @@
"tqdm",
"wheel",
"packaging",
"ipython",
"ipython!=8.7.0",
"tabulate",
"statsmodels>=0.10.1"
],
Expand Down

0 comments on commit cc65dab

Please sign in to comment.