forked from Trusted-AI/AIF360
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Updated 12 example notebooks to run on Colab. Fixes Trusted-AI#385
Signed-off-by: Clayton O'Dell <codell2@nd.edu>
- Loading branch information
Showing
12 changed files
with
12,046 additions
and
10,796 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,183 +1,202 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false | ||
} | ||
}, | ||
"outputs": [ | ||
"cells": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"iteration: 1, error: 0.229, fairness violation: 0.05428400000000001, violated group size: 0.249\n", | ||
"iteration: 2, error: 0.3645, fairness violation: 0.027142000000000006, violated group size: 0.249\n", | ||
"iteration: 3, error: 0.4096666666666666, fairness violation: 0.01809466666666667, violated group size: 0.251\n", | ||
"iteration: 4, error: 0.43225, fairness violation: 0.013571000000000003, violated group size: 0.249\n", | ||
"iteration: 5, error: 0.44580000000000014, fairness violation: 0.0108568, violated group size: 0.251\n", | ||
"iteration: 6, error: 0.4548333333333334, fairness violation: 0.009047333333333338, violated group size: 0.251\n", | ||
"iteration: 7, error: 0.46128571428571435, fairness violation: 0.007754857142857144, violated group size: 0.251\n", | ||
"iteration: 8, error: 0.466125, fairness violation: 0.006785500000000003, violated group size: 0.251\n", | ||
"iteration: 9, error: 0.469888888888889, fairness violation: 0.006031555555555558, violated group size: 0.249\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%matplotlib inline\n", | ||
"import warnings\n", | ||
"warnings.filterwarnings(\"ignore\")\n", | ||
"import sys\n", | ||
"sys.path.append(\"../\")\n", | ||
"from aif360.algorithms.inprocessing import GerryFairClassifier\n", | ||
"from aif360.algorithms.inprocessing.gerryfair.clean import array_to_tuple\n", | ||
"from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult\n", | ||
"from sklearn import svm\n", | ||
"from sklearn import tree\n", | ||
"from sklearn.kernel_ridge import KernelRidge\n", | ||
"from sklearn import linear_model\n", | ||
"from aif360.metrics import BinaryLabelDatasetMetric\n", | ||
"from IPython.display import Image\n", | ||
"import pickle\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"# load data set\n", | ||
"data_set = load_preproc_data_adult(sub_samp=1000, balance=True)\n", | ||
"max_iterations = 10\n", | ||
"C = 100\n", | ||
"print_flag = True\n", | ||
"gamma = .005\n", | ||
"\n", | ||
"fair_model = GerryFairClassifier(C=C, printflag=print_flag, gamma=gamma, fairness_def='FP',\n", | ||
" max_iters=max_iterations, heatmapflag=False)\n", | ||
"# fit method\n", | ||
"fair_model.fit(data_set, early_termination=True)\n", | ||
"\n", | ||
"# predict method. If threshold in (0, 1) produces binary predictions\n", | ||
"dataset_yhat = fair_model.predict(data_set, threshold=False)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false | ||
} | ||
}, | ||
"outputs": [ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Trusted-AI/AIF360/blob/main/examples/demo_short_gerryfair_test.ipynb)" | ||
], | ||
"metadata": { | ||
"id": "qsaaqV7jGtYV" | ||
} | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"0.0060315555555555565\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# auditing \n", | ||
"\n", | ||
"gerry_metric = BinaryLabelDatasetMetric(data_set)\n", | ||
"gamma_disparity = gerry_metric.rich_subgroup(array_to_tuple(dataset_yhat.labels), 'FP')\n", | ||
"print(gamma_disparity)\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false | ||
} | ||
}, | ||
"outputs": [ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false | ||
}, | ||
"id": "9Cb8bfPXGUa0", | ||
"outputId": "165bfd3e-1281-4b34-d21d-137182075fab" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"iteration: 1, error: 0.229, fairness violation: 0.05428400000000001, violated group size: 0.249\n", | ||
"iteration: 2, error: 0.3645, fairness violation: 0.027142000000000006, violated group size: 0.249\n", | ||
"iteration: 3, error: 0.4096666666666666, fairness violation: 0.01809466666666667, violated group size: 0.251\n", | ||
"iteration: 4, error: 0.43225, fairness violation: 0.013571000000000003, violated group size: 0.249\n", | ||
"iteration: 5, error: 0.44580000000000014, fairness violation: 0.0108568, violated group size: 0.251\n", | ||
"iteration: 6, error: 0.4548333333333334, fairness violation: 0.009047333333333338, violated group size: 0.251\n", | ||
"iteration: 7, error: 0.46128571428571435, fairness violation: 0.007754857142857144, violated group size: 0.251\n", | ||
"iteration: 8, error: 0.466125, fairness violation: 0.006785500000000003, violated group size: 0.251\n", | ||
"iteration: 9, error: 0.469888888888889, fairness violation: 0.006031555555555558, violated group size: 0.249\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%matplotlib inline\n", | ||
"import warnings\n", | ||
"warnings.filterwarnings(\"ignore\")\n", | ||
"import sys\n", | ||
"sys.path.append(\"../\")\n", | ||
"from aif360.algorithms.inprocessing import GerryFairClassifier\n", | ||
"from aif360.algorithms.inprocessing.gerryfair.clean import array_to_tuple\n", | ||
"from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult\n", | ||
"from sklearn import svm\n", | ||
"from sklearn import tree\n", | ||
"from sklearn.kernel_ridge import KernelRidge\n", | ||
"from sklearn import linear_model\n", | ||
"from aif360.metrics import BinaryLabelDatasetMetric\n", | ||
"from IPython.display import Image\n", | ||
"import pickle\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"# load data set\n", | ||
"data_set = load_preproc_data_adult(sub_samp=1000, balance=True)\n", | ||
"max_iterations = 10\n", | ||
"C = 100\n", | ||
"print_flag = True\n", | ||
"gamma = .005\n", | ||
"\n", | ||
"fair_model = GerryFairClassifier(C=C, printflag=print_flag, gamma=gamma, fairness_def='FP',\n", | ||
" max_iters=max_iterations, heatmapflag=False)\n", | ||
"# fit method\n", | ||
"fair_model.fit(data_set, early_termination=True)\n", | ||
"\n", | ||
"# predict method. If threshold in (0, 1) produces binary predictions\n", | ||
"dataset_yhat = fair_model.predict(data_set, threshold=False)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false | ||
}, | ||
"id": "cCdltbMWGUa2", | ||
"outputId": "0ab1dd98-6cae-45a1-e51c-41ef9b7193df" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"0.0060315555555555565\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# auditing\n", | ||
"\n", | ||
"gerry_metric = BinaryLabelDatasetMetric(data_set)\n", | ||
"gamma_disparity = gerry_metric.rich_subgroup(array_to_tuple(dataset_yhat.labels), 'FP')\n", | ||
"print(gamma_disparity)\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Curr Predictor: Linear\n", | ||
"Curr Predictor: SVR\n", | ||
"Curr Predictor: Tree\n", | ||
"Curr Predictor: Kernel\n" | ||
] | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false | ||
}, | ||
"id": "on2jaBgsGUa3", | ||
"outputId": "ed58df2b-2288-491b-bf07-ad5d783baa06" | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Curr Predictor: Linear\n", | ||
"Curr Predictor: SVR\n", | ||
"Curr Predictor: Tree\n", | ||
"Curr Predictor: Kernel\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# set to 10 iterations for fast running of notebook - set >= 1000 when running real experiments\n", | ||
"# tests learning with different hypothesis classes\n", | ||
"pareto_iters = 10\n", | ||
"def multiple_classifiers_pareto(dataset, gamma_list=[0.002, 0.005, 0.01], save_results=False, iters=pareto_iters):\n", | ||
"\n", | ||
" ln_predictor = linear_model.LinearRegression()\n", | ||
" svm_predictor = svm.LinearSVR()\n", | ||
" tree_predictor = tree.DecisionTreeRegressor(max_depth=3)\n", | ||
" kernel_predictor = KernelRidge(alpha=1.0, gamma=1.0, kernel='rbf')\n", | ||
" predictor_dict = {'Linear': {'predictor': ln_predictor, 'iters': iters},\n", | ||
" 'SVR': {'predictor': svm_predictor, 'iters': iters},\n", | ||
" 'Tree': {'predictor': tree_predictor, 'iters': iters},\n", | ||
" 'Kernel': {'predictor': kernel_predictor, 'iters': iters}}\n", | ||
"\n", | ||
" results_dict = {}\n", | ||
"\n", | ||
" for pred in predictor_dict:\n", | ||
" print('Curr Predictor: {}'.format(pred))\n", | ||
" predictor = predictor_dict[pred]['predictor']\n", | ||
" max_iters = predictor_dict[pred]['iters']\n", | ||
" fair_clf = GerryFairClassifier(C=100, printflag=True, gamma=1, predictor=predictor, max_iters=max_iters)\n", | ||
" fair_clf.printflag = False\n", | ||
" fair_clf.max_iters=max_iters\n", | ||
" errors, fp_violations, fn_violations = fair_clf.pareto(dataset, gamma_list)\n", | ||
" results_dict[pred] = {'errors': errors, 'fp_violations': fp_violations, 'fn_violations': fn_violations}\n", | ||
" if save_results:\n", | ||
" pickle.dump(results_dict, open('results_dict_' + str(gamma_list) + '_gammas' + str(gamma_list) + '.pkl', 'wb'))\n", | ||
"\n", | ||
"multiple_classifiers_pareto(data_set)\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"pycharm": { | ||
"is_executing": false | ||
}, | ||
"id": "tgGHGbQZGUa3" | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"source": [ | ||
"# set to 10 iterations for fast running of notebook - set >= 1000 when running real experiments\n", | ||
"# tests learning with different hypothesis classes\n", | ||
"pareto_iters = 10\n", | ||
"def multiple_classifiers_pareto(dataset, gamma_list=[0.002, 0.005, 0.01], save_results=False, iters=pareto_iters):\n", | ||
"\n", | ||
" ln_predictor = linear_model.LinearRegression()\n", | ||
" svm_predictor = svm.LinearSVR()\n", | ||
" tree_predictor = tree.DecisionTreeRegressor(max_depth=3)\n", | ||
" kernel_predictor = KernelRidge(alpha=1.0, gamma=1.0, kernel='rbf')\n", | ||
" predictor_dict = {'Linear': {'predictor': ln_predictor, 'iters': iters},\n", | ||
" 'SVR': {'predictor': svm_predictor, 'iters': iters},\n", | ||
" 'Tree': {'predictor': tree_predictor, 'iters': iters},\n", | ||
" 'Kernel': {'predictor': kernel_predictor, 'iters': iters}}\n", | ||
"\n", | ||
" results_dict = {}\n", | ||
"\n", | ||
" for pred in predictor_dict:\n", | ||
" print('Curr Predictor: {}'.format(pred))\n", | ||
" predictor = predictor_dict[pred]['predictor']\n", | ||
" max_iters = predictor_dict[pred]['iters']\n", | ||
" fair_clf = GerryFairClassifier(C=100, printflag=True, gamma=1, predictor=predictor, max_iters=max_iters)\n", | ||
" fair_clf.printflag = False\n", | ||
" fair_clf.max_iters=max_iters\n", | ||
" errors, fp_violations, fn_violations = fair_clf.pareto(dataset, gamma_list)\n", | ||
" results_dict[pred] = {'errors': errors, 'fp_violations': fp_violations, 'fn_violations': fn_violations}\n", | ||
" if save_results:\n", | ||
" pickle.dump(results_dict, open('results_dict_' + str(gamma_list) + '_gammas' + str(gamma_list) + '.pkl', 'wb'))\n", | ||
"\n", | ||
"multiple_classifiers_pareto(data_set)\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.9" | ||
}, | ||
"pycharm": { | ||
"is_executing": false | ||
"stem_cell": { | ||
"cell_type": "raw", | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"source": [] | ||
} | ||
}, | ||
"colab": { | ||
"provenance": [] | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.9" | ||
}, | ||
"pycharm": { | ||
"stem_cell": { | ||
"cell_type": "raw", | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"source": [] | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
Oops, something went wrong.