Skip to content

Commit 6b2b6cc

Browse files
author
Github Actions
committed
Francisco Rivera Valverde: [ADD] Extra visualization example (#189)
1 parent a993f0f commit 6b2b6cc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1119
-150
lines changed

refactor_development/.buildinfo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# Sphinx build info version 1
22
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
3-
config: ad8f043e67f31d900fdc9a0563a66d13
3+
config: da1291ead51b7998a2311fe24055da4c
44
tags: 645f666f9bcd5a90fca523b33c5a78b7
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""
2+
=======================
3+
Visualizing the Results
4+
=======================
5+
6+
Auto-Pytorch uses SMAC to fit individual machine learning algorithms
7+
and then ensembles them together using `Ensemble Selection
8+
<https://www.cs.cornell.edu/~caruana/ctp/ct.papers/caruana.icml04.icdm06long.pdf>`_.
9+
10+
The following example shows how to visualize both the performance
11+
of the individual models and their respective ensemble.
12+
13+
Additionally, as we are compatible with scikit-learn,
14+
we show how to further interact with `Scikit-Learn Inspection
15+
<https://scikit-learn.org/stable/inspection.html>`_ support.
16+
17+
18+
"""
19+
import os
20+
import pickle
21+
import tempfile as tmp
22+
import time
23+
import warnings
24+
25+
# The following variables are not needed for every unix distribution, but are
26+
# set here to prevent problems with multiprocessing in scikit-learn.
27+
os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
28+
os.environ['OMP_NUM_THREADS'] = '1'
29+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
30+
os.environ['MKL_NUM_THREADS'] = '1'
31+
32+
warnings.simplefilter(action='ignore', category=UserWarning)
33+
warnings.simplefilter(action='ignore', category=FutureWarning)
34+
35+
import matplotlib.pyplot as plt
36+
37+
import numpy as np
38+
39+
import pandas as pd
40+
41+
42+
import sklearn.datasets
43+
import sklearn.model_selection
44+
from sklearn.inspection import permutation_importance
45+
46+
from smac.tae import StatusType
47+
48+
49+
from autoPyTorch.api.tabular_classification import TabularClassificationTask
50+
from autoPyTorch.metrics import accuracy
51+
52+
53+
if __name__ == '__main__':
54+
55+
############################################################################
56+
# Data Loading
57+
# ============
58+
59+
# We will use the iris dataset for this toy example
60+
seed = 42
61+
X, y = sklearn.datasets.fetch_openml(data_id=61, return_X_y=True, as_frame=True)
62+
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
63+
X,
64+
y,
65+
random_state=42,
66+
)
67+
68+
############################################################################
69+
# Build and fit a classifier
70+
# ==========================
71+
api = TabularClassificationTask(seed=seed)
72+
api.search(
73+
X_train=X_train,
74+
y_train=y_train,
75+
X_test=X_test.copy(),
76+
y_test=y_test.copy(),
77+
optimize_metric=accuracy.name,
78+
total_walltime_limit=200,
79+
func_eval_time_limit_secs=50
80+
)
81+
82+
############################################################################
83+
# One can also save the model for future inference
84+
# ================================================
85+
86+
# For more details on how to deploy a model, please check
87+
# `Scikit-Learn persistence
88+
# <https://scikit-learn.org/stable/modules/model_persistence.html>`_ support.
89+
with open('estimator.pickle', 'wb') as handle:
90+
pickle.dump(api, handle, protocol=pickle.HIGHEST_PROTOCOL)
91+
92+
# Then let us read it back and use it for our analysis
93+
with open('estimator.pickle', 'rb') as handle:
94+
estimator = pickle.load(handle)
95+
96+
############################################################################
97+
# Plotting the model performance
98+
# ==============================
99+
100+
# We will plot the search incumbent through time.
101+
102+
# Collect the performance of individual machine learning algorithms
103+
# found by SMAC
104+
individual_performances = []
105+
for run_key, run_value in estimator.run_history.data.items():
106+
if run_value.status != StatusType.SUCCESS:
107+
# Ignore runs that did not finish successfully
108+
continue
109+
individual_performances.append({
110+
'Timestamp': pd.Timestamp(
111+
time.strftime(
112+
'%Y-%m-%d %H:%M:%S',
113+
time.localtime(run_value.endtime)
114+
)
115+
),
116+
'single_best_optimization_accuracy': accuracy._optimum - run_value.cost,
117+
'single_best_test_accuracy': np.nan if run_value.additional_info is None else
118+
accuracy._optimum - run_value.additional_info['test_loss'],
119+
})
120+
individual_performance_frame = pd.DataFrame(individual_performances)
121+
122+
# Collect the performance of the ensemble through time
123+
# This ensemble is built from the machine learning algorithms
124+
# found by SMAC
125+
ensemble_performance_frame = pd.DataFrame(estimator.ensemble_performance_history)
126+
127+
# As we are tracking the incumbent, we are interested in the cummax() performance
128+
ensemble_performance_frame['ensemble_optimization_accuracy'] = ensemble_performance_frame[
129+
'train_accuracy'
130+
].cummax()
131+
ensemble_performance_frame['ensemble_test_accuracy'] = ensemble_performance_frame[
132+
'test_accuracy'
133+
].cummax()
134+
ensemble_performance_frame.drop(columns=['test_accuracy', 'train_accuracy'], inplace=True)
135+
individual_performance_frame['single_best_optimization_accuracy'] = individual_performance_frame[
136+
'single_best_optimization_accuracy'
137+
].cummax()
138+
individual_performance_frame['single_best_test_accuracy'] = individual_performance_frame[
139+
'single_best_test_accuracy'
140+
].cummax()
141+
142+
pd.merge(
143+
ensemble_performance_frame,
144+
individual_performance_frame,
145+
on="Timestamp", how='outer'
146+
).sort_values('Timestamp').fillna(method='ffill').plot(
147+
x='Timestamp',
148+
kind='line',
149+
legend=True,
150+
title='Auto-PyTorch accuracy over time',
151+
grid=True,
152+
)
153+
plt.show()
154+
155+
# We then can understand the importance of each input feature using
156+
# a permutation importance analysis. This is done as a proof of concept, to
157+
# showcase that we can leverage the scikit-learn API.
158+
result = permutation_importance(estimator, X_train, y_train, n_repeats=5,
159+
scoring='accuracy',
160+
random_state=seed)
161+
sorted_idx = result.importances_mean.argsort()
162+
163+
fig, ax = plt.subplots()
164+
ax.boxplot(result.importances[sorted_idx].T,
165+
vert=False, labels=X_test.columns[sorted_idx])
166+
ax.set_title("Permutation Importances (Train set)")
167+
fig.tight_layout()
168+
plt.show()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [],
10+
"source": [
11+
"%matplotlib inline"
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"\n# Visualizing the Results\n\nAuto-Pytorch uses SMAC to fit individual machine learning algorithms\nand then ensembles them together using `Ensemble Selection\n<https://www.cs.cornell.edu/~caruana/ctp/ct.papers/caruana.icml04.icdm06long.pdf>`_.\n\nThe following examples shows how to visualize both the performance\nof the individual models and their respective ensemble.\n\nAdditionally, as we are compatible with scikit-learn,\nwe show how to further interact with `Scikit-Learn Inspection\n<https://scikit-learn.org/stable/inspection.html>`_ support.\n"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {
25+
"collapsed": false
26+
},
27+
"outputs": [],
28+
"source": [
29+
"import os\nimport pickle\nimport tempfile as tmp\nimport time\nimport warnings\n\n# The following variables are not needed for every unix distribution, but are\n# highlighted in here to prevent problems with multiprocessing with scikit-learn.\nos.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()\nos.environ['OMP_NUM_THREADS'] = '1'\nos.environ['OPENBLAS_NUM_THREADS'] = '1'\nos.environ['MKL_NUM_THREADS'] = '1'\n\nwarnings.simplefilter(action='ignore', category=UserWarning)\nwarnings.simplefilter(action='ignore', category=FutureWarning)\n\nimport matplotlib.pyplot as plt\n\nimport numpy as np\n\nimport pandas as pd\n\n\nimport sklearn.datasets\nimport sklearn.model_selection\nfrom sklearn.inspection import permutation_importance\n\nfrom smac.tae import StatusType\n\n\nfrom autoPyTorch.api.tabular_classification import TabularClassificationTask\nfrom autoPyTorch.metrics import accuracy\n\n\nif __name__ == '__main__':\n\n ############################################################################\n # Data Loading\n # ============\n\n # We will use the iris dataset for this Toy example\n seed = 42\n X, y = sklearn.datasets.fetch_openml(data_id=61, return_X_y=True, as_frame=True)\n X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(\n X,\n y,\n random_state=42,\n )\n\n ############################################################################\n # Build and fit a classifier\n # ==========================\n api = TabularClassificationTask(seed=seed)\n api.search(\n X_train=X_train,\n y_train=y_train,\n X_test=X_test.copy(),\n y_test=y_test.copy(),\n optimize_metric=accuracy.name,\n total_walltime_limit=200,\n func_eval_time_limit_secs=50\n )\n\n ############################################################################\n # One can also save the model for future inference\n # ================================================\n\n # For more details on how to deploy a model, please check\n # `Scikit-Learn persistence\n # 
<https://scikit-learn.org/stable/modules/model_persistence.html>`_ support.\n with open('estimator.pickle', 'wb') as handle:\n pickle.dump(api, handle, protocol=pickle.HIGHEST_PROTOCOL)\n\n # Then let us read it back and use it for our analysis\n with open('estimator.pickle', 'rb') as handle:\n estimator = pickle.load(handle)\n\n ############################################################################\n # Plotting the model performance\n # ==============================\n\n # We will plot the search incumbent through time.\n\n # Collect the performance of individual machine learning algorithms\n # found by SMAC\n individual_performances = []\n for run_key, run_value in estimator.run_history.data.items():\n if run_value.status != StatusType.SUCCESS:\n # Ignore crashed runs\n continue\n individual_performances.append({\n 'Timestamp': pd.Timestamp(\n time.strftime(\n '%Y-%m-%d %H:%M:%S',\n time.localtime(run_value.endtime)\n )\n ),\n 'single_best_optimization_accuracy': accuracy._optimum - run_value.cost,\n 'single_best_test_accuracy': np.nan if run_value.additional_info is None else\n accuracy._optimum - run_value.additional_info['test_loss'],\n })\n individual_performance_frame = pd.DataFrame(individual_performances)\n\n # Collect the performance of the ensemble through time\n # This ensemble is built from the machine learning algorithms\n # found by SMAC\n ensemble_performance_frame = pd.DataFrame(estimator.ensemble_performance_history)\n\n # As we are tracking the incumbent, we are interested in the cummax() performance\n ensemble_performance_frame['ensemble_optimization_accuracy'] = ensemble_performance_frame[\n 'train_accuracy'\n ].cummax()\n ensemble_performance_frame['ensemble_test_accuracy'] = ensemble_performance_frame[\n 'test_accuracy'\n ].cummax()\n ensemble_performance_frame.drop(columns=['test_accuracy', 'train_accuracy'], inplace=True)\n individual_performance_frame['single_best_optimization_accuracy'] = individual_performance_frame[\n 
'single_best_optimization_accuracy'\n ].cummax()\n individual_performance_frame['single_best_test_accuracy'] = individual_performance_frame[\n 'single_best_test_accuracy'\n ].cummax()\n\n pd.merge(\n ensemble_performance_frame,\n individual_performance_frame,\n on=\"Timestamp\", how='outer'\n ).sort_values('Timestamp').fillna(method='ffill').plot(\n x='Timestamp',\n kind='line',\n legend=True,\n title='Auto-PyTorch accuracy over time',\n grid=True,\n )\n plt.show()\n\n # We then can understand the importance of each input feature using\n # a permutation importance analysis. This is done as a proof of concept, to\n # showcase that we can leverage of scikit-learn API.\n result = permutation_importance(estimator, X_train, y_train, n_repeats=5,\n scoring='accuracy',\n random_state=seed)\n sorted_idx = result.importances_mean.argsort()\n\n fig, ax = plt.subplots()\n ax.boxplot(result.importances[sorted_idx].T,\n vert=False, labels=X_test.columns[sorted_idx])\n ax.set_title(\"Permutation Importances (Train set)\")\n fig.tight_layout()\n plt.show()"
30+
]
31+
}
32+
],
33+
"metadata": {
34+
"kernelspec": {
35+
"display_name": "Python 3",
36+
"language": "python",
37+
"name": "python3"
38+
},
39+
"language_info": {
40+
"codemirror_mode": {
41+
"name": "ipython",
42+
"version": 3
43+
},
44+
"file_extension": ".py",
45+
"mimetype": "text/x-python",
46+
"name": "python",
47+
"nbconvert_exporter": "python",
48+
"pygments_lexer": "ipython3",
49+
"version": "3.8.9"
50+
}
51+
},
52+
"nbformat": 4,
53+
"nbformat_minor": 0
54+
}
Loading
Loading
Loading

refactor_development/_modules/autoPyTorch/api/tabular_classification.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
<head>
55
<meta charset="utf-8" />
66
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
7-
<title>autoPyTorch.api.tabular_classification &#8212; AutoPyTorch 0.0.3 documentation</title>
7+
<title>autoPyTorch.api.tabular_classification &#8212; AutoPyTorch 0.1.0 documentation</title>
88
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
99
<link rel="stylesheet" href="../../../_static/bootstrap-sphinx.css" type="text/css" />
1010
<link rel="stylesheet" type="text/css" href="../../../_static/gallery.css" />
@@ -54,7 +54,7 @@
5454
</button>
5555
<a class="navbar-brand" href="../../../index.html">
5656
Auto-PyTorch</a>
57-
<span class="navbar-text navbar-version pull-left"><b>0.0.3</b></span>
57+
<span class="navbar-text navbar-version pull-left"><b>0.1.0</b></span>
5858
</div>
5959

6060
<div class="collapse navbar-collapse nav-collapse">

refactor_development/_modules/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
<head>
55
<meta charset="utf-8" />
66
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
7-
<title>Overview: module code &#8212; AutoPyTorch 0.0.3 documentation</title>
7+
<title>Overview: module code &#8212; AutoPyTorch 0.1.0 documentation</title>
88
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
99
<link rel="stylesheet" href="../_static/bootstrap-sphinx.css" type="text/css" />
1010
<link rel="stylesheet" type="text/css" href="../_static/gallery.css" />
@@ -54,7 +54,7 @@
5454
</button>
5555
<a class="navbar-brand" href="../index.html">
5656
Auto-PyTorch</a>
57-
<span class="navbar-text navbar-version pull-left"><b>0.0.3</b></span>
57+
<span class="navbar-text navbar-version pull-left"><b>0.1.0</b></span>
5858
</div>
5959

6060
<div class="collapse navbar-collapse nav-collapse">

0 commit comments

Comments
 (0)