Skip to content

Commit

Permalink
Updated 12 example notebooks to run on Colab. Fixes Trusted-AI#385
Browse files Browse the repository at this point in the history
Signed-off-by: Clayton O'Dell <codell2@nd.edu>
  • Loading branch information
codell2 committed Dec 6, 2024
1 parent f5a019c commit 22a79db
Show file tree
Hide file tree
Showing 12 changed files with 12,046 additions and 10,796 deletions.
1,566 changes: 819 additions & 747 deletions examples/demo_FACTS.ipynb

Large diffs are not rendered by default.

1,746 changes: 890 additions & 856 deletions examples/demo_gerryfair.ipynb

Large diffs are not rendered by default.

554 changes: 300 additions & 254 deletions examples/demo_json_explainers.ipynb

Large diffs are not rendered by default.

919 changes: 487 additions & 432 deletions examples/demo_lfr.ipynb

Large diffs are not rendered by default.

2,533 changes: 1,342 additions & 1,191 deletions examples/demo_mdss_classifier_metric.ipynb

Large diffs are not rendered by default.

3,256 changes: 1,722 additions & 1,534 deletions examples/demo_mdss_detector.ipynb

Large diffs are not rendered by default.

937 changes: 499 additions & 438 deletions examples/demo_meta_classifier.ipynb

Large diffs are not rendered by default.

2,798 changes: 1,452 additions & 1,346 deletions examples/demo_ot_metric.ipynb

Large diffs are not rendered by default.

1,582 changes: 833 additions & 749 deletions examples/demo_reject_option_classification.ipynb

Large diffs are not rendered by default.

1,915 changes: 1,013 additions & 902 deletions examples/demo_reweighing_preproc.ipynb

Large diffs are not rendered by default.

369 changes: 194 additions & 175 deletions examples/demo_short_gerryfair_test.ipynb
Original file line number Diff line number Diff line change
@@ -1,183 +1,202 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [
"cells": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 1, error: 0.229, fairness violation: 0.05428400000000001, violated group size: 0.249\n",
"iteration: 2, error: 0.3645, fairness violation: 0.027142000000000006, violated group size: 0.249\n",
"iteration: 3, error: 0.4096666666666666, fairness violation: 0.01809466666666667, violated group size: 0.251\n",
"iteration: 4, error: 0.43225, fairness violation: 0.013571000000000003, violated group size: 0.249\n",
"iteration: 5, error: 0.44580000000000014, fairness violation: 0.0108568, violated group size: 0.251\n",
"iteration: 6, error: 0.4548333333333334, fairness violation: 0.009047333333333338, violated group size: 0.251\n",
"iteration: 7, error: 0.46128571428571435, fairness violation: 0.007754857142857144, violated group size: 0.251\n",
"iteration: 8, error: 0.466125, fairness violation: 0.006785500000000003, violated group size: 0.251\n",
"iteration: 9, error: 0.469888888888889, fairness violation: 0.006031555555555558, violated group size: 0.249\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"import sys\n",
"sys.path.append(\"../\")\n",
"from aif360.algorithms.inprocessing import GerryFairClassifier\n",
"from aif360.algorithms.inprocessing.gerryfair.clean import array_to_tuple\n",
"from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult\n",
"from sklearn import svm\n",
"from sklearn import tree\n",
"from sklearn.kernel_ridge import KernelRidge\n",
"from sklearn import linear_model\n",
"from aif360.metrics import BinaryLabelDatasetMetric\n",
"from IPython.display import Image\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# load data set\n",
"data_set = load_preproc_data_adult(sub_samp=1000, balance=True)\n",
"max_iterations = 10\n",
"C = 100\n",
"print_flag = True\n",
"gamma = .005\n",
"\n",
"fair_model = GerryFairClassifier(C=C, printflag=print_flag, gamma=gamma, fairness_def='FP',\n",
" max_iters=max_iterations, heatmapflag=False)\n",
"# fit method\n",
"fair_model.fit(data_set, early_termination=True)\n",
"\n",
"# predict method. If threshold in (0, 1) produces binary predictions\n",
"dataset_yhat = fair_model.predict(data_set, threshold=False)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [
"cell_type": "markdown",
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/demo_short_gerryfair_test.ipynb)"
],
"metadata": {
"id": "qsaaqV7jGtYV"
}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0060315555555555565\n"
]
}
],
"source": [
"# auditing \n",
"\n",
"gerry_metric = BinaryLabelDatasetMetric(data_set)\n",
"gamma_disparity = gerry_metric.rich_subgroup(array_to_tuple(dataset_yhat.labels), 'FP')\n",
"print(gamma_disparity)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": false
},
"id": "9Cb8bfPXGUa0",
"outputId": "165bfd3e-1281-4b34-d21d-137182075fab"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 1, error: 0.229, fairness violation: 0.05428400000000001, violated group size: 0.249\n",
"iteration: 2, error: 0.3645, fairness violation: 0.027142000000000006, violated group size: 0.249\n",
"iteration: 3, error: 0.4096666666666666, fairness violation: 0.01809466666666667, violated group size: 0.251\n",
"iteration: 4, error: 0.43225, fairness violation: 0.013571000000000003, violated group size: 0.249\n",
"iteration: 5, error: 0.44580000000000014, fairness violation: 0.0108568, violated group size: 0.251\n",
"iteration: 6, error: 0.4548333333333334, fairness violation: 0.009047333333333338, violated group size: 0.251\n",
"iteration: 7, error: 0.46128571428571435, fairness violation: 0.007754857142857144, violated group size: 0.251\n",
"iteration: 8, error: 0.466125, fairness violation: 0.006785500000000003, violated group size: 0.251\n",
"iteration: 9, error: 0.469888888888889, fairness violation: 0.006031555555555558, violated group size: 0.249\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"import sys\n",
"sys.path.append(\"../\")\n",
"from aif360.algorithms.inprocessing import GerryFairClassifier\n",
"from aif360.algorithms.inprocessing.gerryfair.clean import array_to_tuple\n",
"from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult\n",
"from sklearn import svm\n",
"from sklearn import tree\n",
"from sklearn.kernel_ridge import KernelRidge\n",
"from sklearn import linear_model\n",
"from aif360.metrics import BinaryLabelDatasetMetric\n",
"from IPython.display import Image\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# load data set\n",
"data_set = load_preproc_data_adult(sub_samp=1000, balance=True)\n",
"max_iterations = 10\n",
"C = 100\n",
"print_flag = True\n",
"gamma = .005\n",
"\n",
"fair_model = GerryFairClassifier(C=C, printflag=print_flag, gamma=gamma, fairness_def='FP',\n",
" max_iters=max_iterations, heatmapflag=False)\n",
"# fit method\n",
"fair_model.fit(data_set, early_termination=True)\n",
"\n",
"# predict method. If threshold in (0, 1) produces binary predictions\n",
"dataset_yhat = fair_model.predict(data_set, threshold=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": false
},
"id": "cCdltbMWGUa2",
"outputId": "0ab1dd98-6cae-45a1-e51c-41ef9b7193df"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0060315555555555565\n"
]
}
],
"source": [
"# auditing\n",
"\n",
"gerry_metric = BinaryLabelDatasetMetric(data_set)\n",
"gamma_disparity = gerry_metric.rich_subgroup(array_to_tuple(dataset_yhat.labels), 'FP')\n",
"print(gamma_disparity)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Curr Predictor: Linear\n",
"Curr Predictor: SVR\n",
"Curr Predictor: Tree\n",
"Curr Predictor: Kernel\n"
]
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": false
},
"id": "on2jaBgsGUa3",
"outputId": "ed58df2b-2288-491b-bf07-ad5d783baa06"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Curr Predictor: Linear\n",
"Curr Predictor: SVR\n",
"Curr Predictor: Tree\n",
"Curr Predictor: Kernel\n"
]
}
],
"source": [
"# set to 10 iterations for fast running of notebook - set >= 1000 when running real experiments\n",
"# tests learning with different hypothesis classes\n",
"pareto_iters = 10\n",
"def multiple_classifiers_pareto(dataset, gamma_list=[0.002, 0.005, 0.01], save_results=False, iters=pareto_iters):\n",
"\n",
" ln_predictor = linear_model.LinearRegression()\n",
" svm_predictor = svm.LinearSVR()\n",
" tree_predictor = tree.DecisionTreeRegressor(max_depth=3)\n",
" kernel_predictor = KernelRidge(alpha=1.0, gamma=1.0, kernel='rbf')\n",
" predictor_dict = {'Linear': {'predictor': ln_predictor, 'iters': iters},\n",
" 'SVR': {'predictor': svm_predictor, 'iters': iters},\n",
" 'Tree': {'predictor': tree_predictor, 'iters': iters},\n",
" 'Kernel': {'predictor': kernel_predictor, 'iters': iters}}\n",
"\n",
" results_dict = {}\n",
"\n",
" for pred in predictor_dict:\n",
" print('Curr Predictor: {}'.format(pred))\n",
" predictor = predictor_dict[pred]['predictor']\n",
" max_iters = predictor_dict[pred]['iters']\n",
" fair_clf = GerryFairClassifier(C=100, printflag=True, gamma=1, predictor=predictor, max_iters=max_iters)\n",
" fair_clf.printflag = False\n",
" fair_clf.max_iters=max_iters\n",
" errors, fp_violations, fn_violations = fair_clf.pareto(dataset, gamma_list)\n",
" results_dict[pred] = {'errors': errors, 'fp_violations': fp_violations, 'fn_violations': fn_violations}\n",
" if save_results:\n",
" pickle.dump(results_dict, open('results_dict_' + str(gamma_list) + '_gammas' + str(gamma_list) + '.pkl', 'wb'))\n",
"\n",
"multiple_classifiers_pareto(data_set)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": false
},
"id": "tgGHGbQZGUa3"
},
"outputs": [],
"source": []
}
],
"source": [
"# set to 10 iterations for fast running of notebook - set >= 1000 when running real experiments\n",
"# tests learning with different hypothesis classes\n",
"pareto_iters = 10\n",
"def multiple_classifiers_pareto(dataset, gamma_list=[0.002, 0.005, 0.01], save_results=False, iters=pareto_iters):\n",
"\n",
" ln_predictor = linear_model.LinearRegression()\n",
" svm_predictor = svm.LinearSVR()\n",
" tree_predictor = tree.DecisionTreeRegressor(max_depth=3)\n",
" kernel_predictor = KernelRidge(alpha=1.0, gamma=1.0, kernel='rbf')\n",
" predictor_dict = {'Linear': {'predictor': ln_predictor, 'iters': iters},\n",
" 'SVR': {'predictor': svm_predictor, 'iters': iters},\n",
" 'Tree': {'predictor': tree_predictor, 'iters': iters},\n",
" 'Kernel': {'predictor': kernel_predictor, 'iters': iters}}\n",
"\n",
" results_dict = {}\n",
"\n",
" for pred in predictor_dict:\n",
" print('Curr Predictor: {}'.format(pred))\n",
" predictor = predictor_dict[pred]['predictor']\n",
" max_iters = predictor_dict[pred]['iters']\n",
" fair_clf = GerryFairClassifier(C=100, printflag=True, gamma=1, predictor=predictor, max_iters=max_iters)\n",
" fair_clf.printflag = False\n",
" fair_clf.max_iters=max_iters\n",
" errors, fp_violations, fn_violations = fair_clf.pareto(dataset, gamma_list)\n",
" results_dict[pred] = {'errors': errors, 'fp_violations': fp_violations, 'fn_violations': fn_violations}\n",
" if save_results:\n",
" pickle.dump(results_dict, open('results_dict_' + str(gamma_list) + '_gammas' + str(gamma_list) + '.pkl', 'wb'))\n",
"\n",
"multiple_classifiers_pareto(data_set)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
},
"pycharm": {
"is_executing": false
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
},
"colab": {
"provenance": []
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
"nbformat": 4,
"nbformat_minor": 0
}
Loading

0 comments on commit 22a79db

Please sign in to comment.