diff --git a/README.md b/README.md index db5eb574..28a6f0ac 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,12 @@ 8. [How to Contribute](https://opensource.salesforce.com/OmniXAI/latest/omnixai.html#how-to-contribute) 9. [Technical Report and Citing OmniXAI](#technical-report-and-citing-omnixai) +## What's New + +The latest version includes an experimental GPT explainer. This explainer leverages the outcomes +produced by SHAP and MACE to formulate the input prompt for ChatGPT. Subsequently, ChatGPT +analyzes these results and generates the corresponding explanations that provide developers with +a clearer understanding of the rationale behind the model's predictions. ## Introduction @@ -67,9 +73,10 @@ We will continue improving this library to make it more comprehensive in the fut | Permutation explanation | Black box | Global | | ✅ | | | | | Feature visualization | Torch or TF | Global | | | ✅ | | | | Feature maps | Torch or TF | Local | | | ✅ | | | +| GPT explainer | Black box | Local | | ✅ | | | | | LIME | Black box | Local | | ✅ | ✅ | ✅ | | | SHAP | Black box* | Local | | ✅ | ✅ | ✅ | ✅ | -| What-if | Black box | Local | | ✅ | | | | +| What-if | Black box | Local | | ✅ | | | | | Integrated gradient | Torch or TF | Local | | ✅ | ✅ | ✅ | | | Counterfactual | Black box* | Local | | ✅ | ✅ | ✅ | ✅ | | Contrastive explanation | Torch or TF | Local | | | ✅ | | | @@ -90,6 +97,10 @@ This [table](https://opensource.salesforce.com/OmniXAI/latest/index.html#compari shows the comparison between our toolkit/library and other existing XAI toolkits/libraries in literature. +**OmniXAI also integrates ChatGPT for generating plain text explanations given a classification/regression +model on tabular datasets.** The generated results may not be 100% accurate, but it is worth trying this +explainer (we will continue improving the input prompts). + ## Installation You can install ``omnixai`` from PyPI by calling ``pip install omnixai``. You may install from source by @@ -284,6 +295,22 @@ dashboard.show() # Launch the dashboard After opening the Dash app in the browser, we will see a dashboard showing the explanations: ![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo.gif) +You can also use the GPT explainer to generate explanations in text for tabular models: + +```python +explainer = TabularExplainer( + explainers=["gpt"], # The GPT explainer to apply + mode="classification", # The task type + data=train_data, # The data for initializing the explainers + model=model, # The ML model to explain + preprocess=lambda z: transformer.transform(z), # Converts raw features into the model inputs + params={ + "gpt": {"apikey": "xxxx"} + } # Set the OpenAI API KEY +) +local_explanations = explainer.explain(X=test_instances) +``` + For vision tasks, the same interface is used to create explainers and generate explanations. Let's take an image classification model as an example. diff --git a/docs/omnixai.explainers.tabular.agnostic.rst b/docs/omnixai.explainers.tabular.agnostic.rst index 75d914cf..6ae73b11 100644 --- a/docs/omnixai.explainers.tabular.agnostic.rst +++ b/docs/omnixai.explainers.tabular.agnostic.rst @@ -15,6 +15,7 @@ omnixai.explainers.tabular.agnostic package L2X.l2x permutation bias + gpt omnixai.explainers.tabular.agnostic.lime module ----------------------------------------------- @@ -87,3 +88,11 @@ omnixai.explainers.tabular.agnostic.bias module :members: :undoc-members: :show-inheritance: + +omnixai.explainers.tabular.agnostic.gpt module +---------------------------------------------- + +.. automodule:: omnixai.explainers.tabular.agnostic.gpt + :members: + :undoc-members: + :show-inheritance: diff --git a/omnixai/explainers/tabular/__init__.py b/omnixai/explainers/tabular/__init__.py index f25b587d..a42f3b1d 100644 --- a/omnixai/explainers/tabular/__init__.py +++ b/omnixai/explainers/tabular/__init__.py @@ -14,6 +14,7 @@ from .agnostic.permutation import PermutationImportance from .agnostic.shap_global import GlobalShapTabular from .agnostic.bias import BiasAnalyzer +from .agnostic.gpt import GPTExplainer from .counterfactual.mace.mace import MACEExplainer from .counterfactual.ce import CounterfactualExplainer from .counterfactual.knn import KNNCounterfactualExplainer @@ -36,6 +37,7 @@ "PermutationImportance", "GlobalShapTabular", "BiasAnalyzer", + "GPTExplainer", "MACEExplainer", "CounterfactualExplainer", "KNNCounterfactualExplainer", diff --git a/omnixai/explainers/tabular/agnostic/gpt.py b/omnixai/explainers/tabular/agnostic/gpt.py new file mode 100644 index 00000000..b04fbe6e --- /dev/null +++ b/omnixai/explainers/tabular/agnostic/gpt.py @@ -0,0 +1,153 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +The explainer based ChatGPT. +""" +import os +import openai +from typing import Callable, List +from omnixai.data.tabular import Tabular +from omnixai.explainers.base import ExplainerBase +from omnixai.explainers.tabular.agnostic.shap import ShapTabular +from omnixai.explainers.tabular.counterfactual.mace.mace import MACEExplainer +from omnixai.explanations.base import PlainText + + +class GPTExplainer(ExplainerBase): + """ + The explainer based on ChatGPT. The input prompt consists of the feature importance scores + and the counterfactual examples (if used). The explanations will be the text generated by + ChatGPT. + """ + explanation_type = "local" + alias = ["gpt"] + + def __init__( + self, + training_data: Tabular, + predict_function: Callable, + apikey: str, + mode: str = "classification", + ignored_features: List = None, + include_counterfactual: bool = True, + openai_model: str = "gpt-3.5-turbo", + **kwargs + ): + """ + :param training_data: The data used to initialize a SHAP explainer. ``training_data`` + can be the training dataset for training the machine learning model. + :param predict_function: The prediction function corresponding to the model to explain. + When the model is for classification, the outputs of the ``predict_function`` + are the class probabilities. When the model is for regression, the outputs of + the ``predict_function`` are the estimated values. + :param apikey: The OpenAI API Key. + :param mode: The task type, e.g., `classification` or `regression`. + :param ignored_features: The features ignored in computing feature importance scores. + :param include_counterfactual: Whether to include counterfactual explanations in the results. + :param openai_model: The model type for chat completion. + :param kwargs: Additional parameters to initialize `shap.KernelExplainer`, e.g., ``nsamples``. + Please refer to the doc of `shap.KernelExplainer`. + """ + super().__init__() + self.apikey = apikey + self.openai_model = openai_model + self.shap_explainer = ShapTabular( + training_data=training_data, + predict_function=predict_function, + mode=mode, + ignored_features=ignored_features, + nsamples=150 + ) + if include_counterfactual and mode == "classification": + self.mace_explainer = MACEExplainer( + training_data=training_data, + predict_function=predict_function, + mode=mode, + ignored_features=ignored_features, + ) + else: + self.mace_explainer = None + + @staticmethod + def _generate_prompt( + shap_explanation, + mace_explanation=None, + mode="classification", + top_k=50 + ): + system_prompt = \ + f"You are a helpful assistant for explaining prediction results generated " \ + f"by a machine learning {mode} model based on the information provided below. " \ + f"Your answers should be accurate and concise." + + prompts = [] + for i, (feature, value, score) in enumerate(zip( + shap_explanation["features"], shap_explanation["values"], shap_explanation["scores"])): + if i < top_k: + prompts.append('{}. "{} = {}": {:.4f}'.format(i, feature, value, score)) + context_prompt = 'Firstly, given the following feature importance scores in the format ' \ + '": ":\n\n' + "\n".join(prompts) + + if mode == "classification": + question_prompt = f"Please explain why this example is classified as " \ + f"label_{shap_explanation['target_label']}." + else: + question_prompt = "Please explain why this example has this predicted value." + context_prompt += f"\n\n{question_prompt}" + "\nYour answer should be concise and accurate." + + if mace_explanation is not None and mace_explanation["counterfactual"] is not None: + df = mace_explanation["query"] + cfs = mace_explanation["counterfactual"] + feature_names = list(df.columns) + feature_values = df.values[0] + cf_label = cfs["label"].values[0] + + prompts = [] + for i, values in enumerate(cfs.values): + changed_features = [] + for name, x, y in zip(feature_names, feature_values, values): + if name != "label" and x != y: + changed_features.append(f'"{name}" = "{y}"') + if len(changed_features) > 0: + prompts.append("{}. If {}, then the predicted label will be label_{} instead of label_{}".format( + i, " and ".join(changed_features), cf_label, shap_explanation['target_label'])) + mace_prompt = \ + "Then given the following results generated by the MACE counterfactual explainer:" \ + "\n\n{}".format("\n".join(prompts[:2])) + context_prompt += \ + "\n\n" + mace_prompt + "\n\n" + \ + "Please show how to change feature values to change the predicted label. " \ + "\nYour answer should be concise and accurate." + + return system_prompt, context_prompt + + @staticmethod + def _api_call(prompt, apikey, model): + openai.api_key = os.getenv("OPENAI_API_KEY") if not apikey else apikey + if not openai.api_key: + raise RuntimeError("Please set your OpenAI API KEY.") + + completion = openai.ChatCompletion.create( + model=model, + messages=[ + {"role": "system", "content": prompt[0]}, + {"role": "user", "content": prompt[1]} + ] + ) + return completion.choices[0].message.content + + def explain(self, X, **kwargs) -> PlainText: + explanations = PlainText() + shap_explanations = self.shap_explainer.explain(X, nsamples=100) + mace_explanations = self.mace_explainer.explain(X) if self.mace_explainer is not None else None + + for i, e in enumerate(shap_explanations.get_explanations()): + mace = mace_explanations.get_explanations()[i] if mace_explanations is not None else None + input_prompt = self._generate_prompt(e, mace_explanation=mace) + explanation = self._api_call(input_prompt, self.apikey, self.openai_model) + explanations.add(instance=X.iloc(i).to_pd(), text=explanation) + return explanations diff --git a/omnixai/explanations/base.py b/omnixai/explanations/base.py index 6014280d..0404df4f 100644 --- a/omnixai/explanations/base.py +++ b/omnixai/explanations/base.py @@ -297,3 +297,67 @@ def from_dict(cls, d): self = cls.__new__(PredictedResults) self.results = d["results"] return self + + +class PlainText(ExplanationBase): + """ + The class for plain text explanations. + """ + + def __init__(self, explanations=None): + """ + :param explanations: The explanation results for initializing ``FeatureImportance``, + which is optional. + """ + super().__init__() + self.explanations = [] if explanations is None else explanations + + def __repr__(self): + return repr(self.explanations) + + def __getitem__(self, i: int): + assert i < len(self.explanations) + return PlainText(explanations=[self.explanations[i]]) + + def add(self, instance, text, **kwargs): + """ + Adds the generated explanation corresponding to one instance. + + :param instance: The instance to explain. + :param text: The text explanation of the given instance. + """ + e = {"instance": instance, "text": text} + e.update(kwargs) + self.explanations.append(e) + + def get_explanations(self, index=None): + """ + Gets the generated explanations. + + :param index: The index of an explanation result stored in ``PlainText``. + When ``index`` is None, the function returns a list of all the explanations. + :return: The explanation for one specific instance (a dict) + or the explanations for all the instances (a list of dicts). + Each dict has the following format: `{"instance": the input instance, + "text": the corresponding explanations in plain text}`. + :rtype: Union[Dict, List] + """ + return self.explanations if index is None else self.explanations[index] + + def plot(self, **kwargs): + raise NotImplementedError + + def plotly_plot(self, **kwargs): + raise NotImplementedError + + def ipython_plot(self, **kwargs): + raise NotImplementedError + + @classmethod + def from_dict(cls, d): + import pandas as pd + explanations = [] + for e in d["explanations"]: + e["instance"] = pd.DataFrame.from_dict(e["instance"]) + explanations.append(e) + return PlainText(explanations=explanations) diff --git a/omnixai/tests/explainers/gpt/gpt_explainer_classification.py b/omnixai/tests/explainers/gpt/gpt_explainer_classification.py new file mode 100644 index 00000000..debeab67 --- /dev/null +++ b/omnixai/tests/explainers/gpt/gpt_explainer_classification.py @@ -0,0 +1,38 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +import os +import unittest +import pprint +from omnixai.utils.misc import set_random_seed +from omnixai.explainers.tabular.agnostic.gpt import GPTExplainer +from omnixai.tests.explainers.tasks import TabularClassification + + +class TestGPTExplainer(unittest.TestCase): + + def test(self): + base_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") + task = TabularClassification(base_folder).train_adult(num_training_samples=2000) + predict_function = lambda z: task.model.predict_proba(task.transform.transform(z)) + + set_random_seed() + explainer = GPTExplainer( + training_data=task.train_data, + predict_function=predict_function, + ignored_features=None, + apikey="xxx" + ) + + i = 1653 + test_x = task.test_data.iloc(i) + print(predict_function(test_x)) + explanations = explainer.explain(test_x) + pprint.pprint(explanations.get_explanations(index=0)["text"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/setup.py b/setup.py index 12ac7056..ae6fa504 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ setup( name="omnixai", - version="1.2.5", + version="1.3.0", author="Wenzhuo Yang, Hung Le, Tanmay Shivprasad Laud, Silvio Savarese, Steven C.H. Hoi", description="OmniXAI: An Explainable AI Toolbox", long_description=open("README.md", "r", encoding="utf-8").read(), @@ -33,7 +33,7 @@ package_dir={"omnixai": "omnixai"}, package_data={"omnixai": ["visualization/assets/*"]}, install_requires=[ - "numpy>=1.17", + "numpy>=1.17,<1.24", "pandas>=1.1.0", "scikit-learn>=0.24,<1.2", "scipy>=1.5.0", @@ -51,6 +51,7 @@ "ipython>=8.10.0", "tabulate", "statsmodels>=0.10.1", + "openai" ], extras_require=extras_require, python_requires=">=3.7,<4", diff --git a/tutorials/tabular/gpt.ipynb b/tutorials/tabular/gpt.ipynb new file mode 100644 index 00000000..e2af3665 --- /dev/null +++ b/tutorials/tabular/gpt.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### GPT explainer for income prediction" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sklearn\n", + "import xgboost\n", + "import numpy as np\n", + "import pandas as pd\n", + "from omnixai.data.tabular import Tabular\n", + "from omnixai.preprocessing.tabular import TabularTransform\n", + "from omnixai.explainers.tabular import GPTExplainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The dataset used in this example is for income prediction (https://archive.ics.uci.edu/ml/datasets/adult). We recommend using `Tabular` to represent a tabular dataset, which can be constructed from a pandas dataframe or a numpy array. To create a `Tabular` instance given a numpy array, one needs to specify the data, the feature names, the categorical feature names (if exists) and the target/label column name (if exists)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Age Workclass fnlwgt Education Education-Num \\\n", + "0 39 State-gov 77516 Bachelors 13 \n", + "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", + "2 38 Private 215646 HS-grad 9 \n", + "3 53 Private 234721 11th 7 \n", + "4 28 Private 338409 Bachelors 13 \n", + "... .. ... ... ... ... \n", + "32556 27 Private 257302 Assoc-acdm 12 \n", + "32557 40 Private 154374 HS-grad 9 \n", + "32558 58 Private 151910 HS-grad 9 \n", + "32559 22 Private 201490 HS-grad 9 \n", + "32560 52 Self-emp-inc 287927 HS-grad 9 \n", + "\n", + " Marital Status Occupation Relationship Race Sex \\\n", + "0 Never-married Adm-clerical Not-in-family White Male \n", + "1 Married-civ-spouse Exec-managerial Husband White Male \n", + "2 Divorced Handlers-cleaners Not-in-family White Male \n", + "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", + "4 Married-civ-spouse Prof-specialty Wife Black Female \n", + "... ... ... ... ... ... \n", + "32556 Married-civ-spouse Tech-support Wife White Female \n", + "32557 Married-civ-spouse Machine-op-inspct Husband White Male \n", + "32558 Widowed Adm-clerical Unmarried White Female \n", + "32559 Never-married Adm-clerical Own-child White Male \n", + "32560 Married-civ-spouse Exec-managerial Wife White Female \n", + "\n", + " Capital Gain Capital Loss Hours per week Country label \n", + "0 2174 0 40 United-States <=50K \n", + "1 0 0 13 United-States <=50K \n", + "2 0 0 40 United-States <=50K \n", + "3 0 0 40 United-States <=50K \n", + "4 0 0 40 Cuba <=50K \n", + "... ... ... ... ... ... \n", + "32556 0 0 38 United-States <=50K \n", + "32557 0 0 40 United-States >50K \n", + "32558 0 0 40 United-States <=50K \n", + "32559 0 0 20 United-States <=50K \n", + "32560 15024 0 40 United-States >50K \n", + "\n", + "[32561 rows x 15 columns]\n" + ] + } + ], + "source": [ + "feature_names = [\n", + " \"Age\", \"Workclass\", \"fnlwgt\", \"Education\",\n", + " \"Education-Num\", \"Marital Status\", \"Occupation\",\n", + " \"Relationship\", \"Race\", \"Sex\", \"Capital Gain\",\n", + " \"Capital Loss\", \"Hours per week\", \"Country\", \"label\"\n", + "]\n", + "data = np.genfromtxt(os.path.join('../data', 'adult.data'), delimiter=', ', dtype=str)\n", + "tabular_data = Tabular(\n", + " data,\n", + " feature_columns=feature_names,\n", + " categorical_columns=[feature_names[i] for i in [1, 3, 5, 6, 7, 8, 9, 13]],\n", + " target_column='label'\n", + ")\n", + "print(tabular_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`TabularTransform` is a special transform designed for tabular data. By default, it converts categorical features into one-hot encoding, and keeps continuous-valued features (if one wants to normalize continuous-valued features, set the parameter `cont_transform` in `TabularTransform` to `Standard` or `MinMax`). The `transform` method of `TabularTransform` will transform a `Tabular` instance into a numpy array. If the `Tabular` instance has a target/label column, the last column of the transformed numpy array will be the target/label. \n", + "\n", + "If one wants some other transformations that are not supported in the library, one can simply convert the `Tabular` instance into a pandas dataframe by calling `Tabular.to_pd()` and try different transformations with it.\n", + "\n", + "After data preprocessing, we can train a XGBoost classifier for this task (one may try other classifiers). " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training data shape: (26048, 108)\n", + "Test data shape: (6513, 108)\n", + "Test accuracy: 0.8668816213726394\n" + ] + } + ], + "source": [ + "np.random.seed(1)\n", + "transformer = TabularTransform().fit(tabular_data)\n", + "class_names = transformer.class_names\n", + "x = transformer.transform(tabular_data)\n", + "train, test, labels_train, labels_test = \\\n", + " sklearn.model_selection.train_test_split(x[:, :-1], x[:, -1], train_size=0.80)\n", + "print('Training data shape: {}'.format(train.shape))\n", + "print('Test data shape: {}'.format(test.shape))\n", + "\n", + "gbtree = xgboost.XGBClassifier(n_estimators=300, max_depth=5)\n", + "gbtree.fit(train, labels_train)\n", + "print('Test accuracy: {}'.format(\n", + " sklearn.metrics.accuracy_score(labels_test, gbtree.predict(test))))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The prediction function takes a `Tabular` instance as its inputs, and outputs the class probabilities for classification tasks or the estimated values for regression tasks. In this example, we simply call `transformer.transform` to do data preprocessing followed by the prediction function of `gbtree`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "predict_function=lambda z: gbtree.predict_proba(transformer.transform(z))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To initialize a GPT explainer, we need to set:\n", + " \n", + " - `training_data`: The data used to initialize a SHAP explainer. ``training_data`` can be the training dataset for training the machine learning model. If the training dataset is too large, ``training_data`` can be a subset of it by applying `omnixai.sampler.tabular.Sampler.subsample`.\n", + " - `predict_function`: The prediction function corresponding to the model.\n", + " - `mode`: The task type, e.g., \"classification\" or \"regression\".\n", + " - `apikey`: The OpenAI API key." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using 150 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n" + ] + } + ], + "source": [ + "explainer = GPTExplainer(\n", + " training_data=tabular_data,\n", + " predict_function=predict_function,\n", + " apikey=\"sk-xxx\"\n", + ")\n", + "# Apply an inverse transform, i.e., converting the numpy array back to `Tabular`\n", + "test_instances = transformer.invert(test)\n", + "test_x = test_instances[1653]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We are now ready to generate explanations:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.0365750789642334, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": 10, + "postfix": null, + "prefix": "", + "rate": null, + "total": 1, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "0f3d5b4cb77b4364b694a046df79d2d2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00