Skip to content

Commit

Permalink
Merge pull request #84 from salesforce/explain_gpt
Browse files Browse the repository at this point in the history
Explain gpt
  • Loading branch information
yangwenzhuo08 authored May 25, 2023
2 parents eb200c6 + 15caeea commit c799391
Show file tree
Hide file tree
Showing 8 changed files with 605 additions and 3 deletions.
29 changes: 28 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
8. [How to Contribute](https://opensource.salesforce.com/OmniXAI/latest/omnixai.html#how-to-contribute)
9. [Technical Report and Citing OmniXAI](#technical-report-and-citing-omnixai)

## What's New

The latest version includes an experimental GPT explainer. This explainer leverages the outcomes
produced by SHAP and MACE to formulate the input prompt for ChatGPT. Subsequently, ChatGPT
analyzes these results and generates the corresponding explanations that provide developers with
a clearer understanding of the rationale behind the model's predictions.

## Introduction

Expand Down Expand Up @@ -67,9 +73,10 @@ We will continue improving this library to make it more comprehensive in the fut
| Permutation explanation | Black box | Global | || | | |
| Feature visualization | Torch or TF | Global | | || | |
| Feature maps | Torch or TF | Local | | || | |
| GPT explainer | Black box | Local | || | | |
| LIME | Black box | Local | |||| |
| SHAP | Black box* | Local | |||||
| What-if | Black box | Local | || | | |
| What-if | Black box | Local | || | | |
| Integrated gradient | Torch or TF | Local | |||| |
| Counterfactual | Black box* | Local | |||||
| Contrastive explanation | Torch or TF | Local | | || | |
Expand All @@ -90,6 +97,10 @@ This [table](https://opensource.salesforce.com/OmniXAI/latest/index.html#compari
shows the comparison between our toolkit/library and other existing XAI toolkits/libraries
in literature.

**OmniXAI also integrates ChatGPT for generating plain text explanations given a classification/regression
model on tabular datasets.** The generated results may not be 100% accurate, but it is worth trying this
explainer (we will continue improving the input prompts).

## Installation

You can install ``omnixai`` from PyPI by calling ``pip install omnixai``. You may install from source by
Expand Down Expand Up @@ -284,6 +295,22 @@ dashboard.show() # Launch the dashboard
After opening the Dash app in the browser, we will see a dashboard showing the explanations:
![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo.gif)

You can also use the GPT explainer to generate explanations in text for tabular models:

```python
explainer = TabularExplainer(
explainers=["gpt"], # The GPT explainer to apply
mode="classification", # The task type
data=train_data, # The data for initializing the explainers
model=model, # The ML model to explain
preprocess=lambda z: transformer.transform(z), # Converts raw features into the model inputs
params={
"gpt": {"apikey": "xxxx"}
} # Set the OpenAI API KEY
)
local_explanations = explainer.explain(X=test_instances)
```

For vision tasks, the same interface is used to create explainers and generate explanations.
Let's take an image classification model as an example.

Expand Down
9 changes: 9 additions & 0 deletions docs/omnixai.explainers.tabular.agnostic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ omnixai.explainers.tabular.agnostic package
L2X.l2x
permutation
bias
gpt

omnixai.explainers.tabular.agnostic.lime module
-----------------------------------------------
Expand Down Expand Up @@ -87,3 +88,11 @@ omnixai.explainers.tabular.agnostic.bias module
:members:
:undoc-members:
:show-inheritance:

omnixai.explainers.tabular.agnostic.gpt module
----------------------------------------------

.. automodule:: omnixai.explainers.tabular.agnostic.gpt
:members:
:undoc-members:
:show-inheritance:
2 changes: 2 additions & 0 deletions omnixai/explainers/tabular/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .agnostic.permutation import PermutationImportance
from .agnostic.shap_global import GlobalShapTabular
from .agnostic.bias import BiasAnalyzer
from .agnostic.gpt import GPTExplainer
from .counterfactual.mace.mace import MACEExplainer
from .counterfactual.ce import CounterfactualExplainer
from .counterfactual.knn import KNNCounterfactualExplainer
Expand All @@ -36,6 +37,7 @@
"PermutationImportance",
"GlobalShapTabular",
"BiasAnalyzer",
"GPTExplainer",
"MACEExplainer",
"CounterfactualExplainer",
"KNNCounterfactualExplainer",
Expand Down
153 changes: 153 additions & 0 deletions omnixai/explainers/tabular/agnostic/gpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#
# Copyright (c) 2023 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
The explainer based ChatGPT.
"""
import os
import openai
from typing import Callable, List
from omnixai.data.tabular import Tabular
from omnixai.explainers.base import ExplainerBase
from omnixai.explainers.tabular.agnostic.shap import ShapTabular
from omnixai.explainers.tabular.counterfactual.mace.mace import MACEExplainer
from omnixai.explanations.base import PlainText


class GPTExplainer(ExplainerBase):
"""
The explainer based on ChatGPT. The input prompt consists of the feature importance scores
and the counterfactual examples (if used). The explanations will be the text generated by
ChatGPT.
"""
explanation_type = "local"
alias = ["gpt"]

def __init__(
self,
training_data: Tabular,
predict_function: Callable,
apikey: str,
mode: str = "classification",
ignored_features: List = None,
include_counterfactual: bool = True,
openai_model: str = "gpt-3.5-turbo",
**kwargs
):
"""
:param training_data: The data used to initialize a SHAP explainer. ``training_data``
can be the training dataset for training the machine learning model.
:param predict_function: The prediction function corresponding to the model to explain.
When the model is for classification, the outputs of the ``predict_function``
are the class probabilities. When the model is for regression, the outputs of
the ``predict_function`` are the estimated values.
:param apikey: The OpenAI API Key.
:param mode: The task type, e.g., `classification` or `regression`.
:param ignored_features: The features ignored in computing feature importance scores.
:param include_counterfactual: Whether to include counterfactual explanations in the results.
:param openai_model: The model type for chat completion.
:param kwargs: Additional parameters to initialize `shap.KernelExplainer`, e.g., ``nsamples``.
Please refer to the doc of `shap.KernelExplainer`.
"""
super().__init__()
self.apikey = apikey
self.openai_model = openai_model
self.shap_explainer = ShapTabular(
training_data=training_data,
predict_function=predict_function,
mode=mode,
ignored_features=ignored_features,
nsamples=150
)
if include_counterfactual and mode == "classification":
self.mace_explainer = MACEExplainer(
training_data=training_data,
predict_function=predict_function,
mode=mode,
ignored_features=ignored_features,
)
else:
self.mace_explainer = None

@staticmethod
def _generate_prompt(
shap_explanation,
mace_explanation=None,
mode="classification",
top_k=50
):
system_prompt = \
f"You are a helpful assistant for explaining prediction results generated " \
f"by a machine learning {mode} model based on the information provided below. " \
f"Your answers should be accurate and concise."

prompts = []
for i, (feature, value, score) in enumerate(zip(
shap_explanation["features"], shap_explanation["values"], shap_explanation["scores"])):
if i < top_k:
prompts.append('{}. "{} = {}": {:.4f}'.format(i, feature, value, score))
context_prompt = 'Firstly, given the following feature importance scores in the format ' \
'"<feature name>: <feature importance score>":\n\n' + "\n".join(prompts)

if mode == "classification":
question_prompt = f"Please explain why this example is classified as " \
f"label_{shap_explanation['target_label']}."
else:
question_prompt = "Please explain why this example has this predicted value."
context_prompt += f"\n\n{question_prompt}" + "\nYour answer should be concise and accurate."

if mace_explanation is not None and mace_explanation["counterfactual"] is not None:
df = mace_explanation["query"]
cfs = mace_explanation["counterfactual"]
feature_names = list(df.columns)
feature_values = df.values[0]
cf_label = cfs["label"].values[0]

prompts = []
for i, values in enumerate(cfs.values):
changed_features = []
for name, x, y in zip(feature_names, feature_values, values):
if name != "label" and x != y:
changed_features.append(f'"{name}" = "{y}"')
if len(changed_features) > 0:
prompts.append("{}. If {}, then the predicted label will be label_{} instead of label_{}".format(
i, " and ".join(changed_features), cf_label, shap_explanation['target_label']))
mace_prompt = \
"Then given the following results generated by the MACE counterfactual explainer:" \
"\n\n{}".format("\n".join(prompts[:2]))
context_prompt += \
"\n\n" + mace_prompt + "\n\n" + \
"Please show how to change feature values to change the predicted label. " \
"\nYour answer should be concise and accurate."

return system_prompt, context_prompt

@staticmethod
def _api_call(prompt, apikey, model):
openai.api_key = os.getenv("OPENAI_API_KEY") if not apikey else apikey
if not openai.api_key:
raise RuntimeError("Please set your OpenAI API KEY.")

completion = openai.ChatCompletion.create(
model=model,
messages=[
{"role": "system", "content": prompt[0]},
{"role": "user", "content": prompt[1]}
]
)
return completion.choices[0].message.content

def explain(self, X, **kwargs) -> PlainText:
explanations = PlainText()
shap_explanations = self.shap_explainer.explain(X, nsamples=100)
mace_explanations = self.mace_explainer.explain(X) if self.mace_explainer is not None else None

for i, e in enumerate(shap_explanations.get_explanations()):
mace = mace_explanations.get_explanations()[i] if mace_explanations is not None else None
input_prompt = self._generate_prompt(e, mace_explanation=mace)
explanation = self._api_call(input_prompt, self.apikey, self.openai_model)
explanations.add(instance=X.iloc(i).to_pd(), text=explanation)
return explanations
64 changes: 64 additions & 0 deletions omnixai/explanations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,67 @@ def from_dict(cls, d):
self = cls.__new__(PredictedResults)
self.results = d["results"]
return self


class PlainText(ExplanationBase):
"""
The class for plain text explanations.
"""

def __init__(self, explanations=None):
"""
:param explanations: The explanation results for initializing ``FeatureImportance``,
which is optional.
"""
super().__init__()
self.explanations = [] if explanations is None else explanations

def __repr__(self):
return repr(self.explanations)

def __getitem__(self, i: int):
assert i < len(self.explanations)
return PlainText(explanations=[self.explanations[i]])

def add(self, instance, text, **kwargs):
"""
Adds the generated explanation corresponding to one instance.
:param instance: The instance to explain.
:param text: The text explanation of the given instance.
"""
e = {"instance": instance, "text": text}
e.update(kwargs)
self.explanations.append(e)

def get_explanations(self, index=None):
"""
Gets the generated explanations.
:param index: The index of an explanation result stored in ``PlainText``.
When ``index`` is None, the function returns a list of all the explanations.
:return: The explanation for one specific instance (a dict)
or the explanations for all the instances (a list of dicts).
Each dict has the following format: `{"instance": the input instance,
"text": the corresponding explanations in plain text}`.
:rtype: Union[Dict, List]
"""
return self.explanations if index is None else self.explanations[index]

def plot(self, **kwargs):
raise NotImplementedError

def plotly_plot(self, **kwargs):
raise NotImplementedError

def ipython_plot(self, **kwargs):
raise NotImplementedError

@classmethod
def from_dict(cls, d):
import pandas as pd
explanations = []
for e in d["explanations"]:
e["instance"] = pd.DataFrame.from_dict(e["instance"])
explanations.append(e)
return PlainText(explanations=explanations)
38 changes: 38 additions & 0 deletions omnixai/tests/explainers/gpt/gpt_explainer_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# Copyright (c) 2023 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
import os
import unittest
import pprint
from omnixai.utils.misc import set_random_seed
from omnixai.explainers.tabular.agnostic.gpt import GPTExplainer
from omnixai.tests.explainers.tasks import TabularClassification


class TestGPTExplainer(unittest.TestCase):

def test(self):
base_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
task = TabularClassification(base_folder).train_adult(num_training_samples=2000)
predict_function = lambda z: task.model.predict_proba(task.transform.transform(z))

set_random_seed()
explainer = GPTExplainer(
training_data=task.train_data,
predict_function=predict_function,
ignored_features=None,
apikey="xxx"
)

i = 1653
test_x = task.test_data.iloc(i)
print(predict_function(test_x))
explanations = explainer.explain(test_x)
pprint.pprint(explanations.get_explanations(index=0)["text"])


if __name__ == "__main__":
unittest.main()
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

setup(
name="omnixai",
version="1.2.5",
version="1.3.0",
author="Wenzhuo Yang, Hung Le, Tanmay Shivprasad Laud, Silvio Savarese, Steven C.H. Hoi",
description="OmniXAI: An Explainable AI Toolbox",
long_description=open("README.md", "r", encoding="utf-8").read(),
Expand All @@ -33,7 +33,7 @@
package_dir={"omnixai": "omnixai"},
package_data={"omnixai": ["visualization/assets/*"]},
install_requires=[
"numpy>=1.17",
"numpy>=1.17,<1.24",
"pandas>=1.1.0",
"scikit-learn>=0.24,<1.2",
"scipy>=1.5.0",
Expand All @@ -51,6 +51,7 @@
"ipython>=8.10.0",
"tabulate",
"statsmodels>=0.10.1",
"openai"
],
extras_require=extras_require,
python_requires=">=3.7,<4",
Expand Down
Loading

0 comments on commit c799391

Please sign in to comment.