Skip to content

Commit 94df1e3

Browse files
[ADD] Extra visualization example (#189)
* [ADD] Extra visualization example * Update docs/manual.rst Co-authored-by: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com> * Update docs/manual.rst Co-authored-by: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com> * [Fix] missing version * Update examples/tabular/40_advanced/example_visualization.py Co-authored-by: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com> * [FIX] make docs more clear to the user Co-authored-by: Ravin Kohli <13005107+ravinkohli@users.noreply.github.com>
1 parent 9f4b855 commit 94df1e3

File tree

5 files changed

+215
-4
lines changed

5 files changed

+215
-4
lines changed

autoPyTorch/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Version information."""
22

33
# The following line *must* be the last in the module, exactly as formatted:
4-
__version__ = "0.0.3"
4+
__version__ = "0.1.0"

autoPyTorch/api/base_task.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
576576
assert self._dask_client is not None
577577

578578
self._logger.info("Starting to create traditional classifier predictions.")
579+
starttime = time.time()
579580

580581
# Initialise run history for the traditional classifiers
581582
run_history = RunHistory()
@@ -649,6 +650,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
649650
origin = additional_info['configuration_origin']
650651
run_history.add(config=configuration, cost=cost,
651652
time=runtime, status=status, seed=self.seed,
653+
starttime=starttime, endtime=starttime + runtime,
652654
origin=origin)
653655
else:
654656
if additional_info.get('exitcode') == -6:

docs/manual.rst

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,45 @@
66
Manual
77
======
88

9-
TODO
9+
This manual shows how to get started with Auto-PyTorch. We recommend going over the examples first.
10+
There are additional recommendations on how to interact with the API, further below in this manual.
11+
However, you are welcome to contribute to this documentation by making a Pull-Request.
12+
13+
The searching starts by calling `search()` function of each supported task.
14+
Currently, we are supporting Tabular classification and Tabular Regression.
15+
We expand the support to image processing tasks in the future.
16+
17+
Examples
18+
========
19+
* `Classification <examples/tabular/20_basics/example_tabular_classification.html>`_
20+
* `Regression <examples/tabular/20_basics/example_tabular_regression.html>`_
21+
* `Customizing the search space <examples/tabular/40_advanced/example_custom_configuration_space.html>`_
22+
* `Changing the resampling strategy <examples/tabular/40_advanced/example_resampling_strategy.html>`_
23+
* `Visualizing the results <examples/tabular/40_advanced/example_visualization.html>`_
24+
25+
Resource Allocation
26+
===================
27+
28+
Auto-PyTorch allows to control the maximum allowed resident set size memory (RSS) that an estimator can use. By providing the `memory_limit` argument to the `search()` method, one can make sure that neither the individual machine learning models fitted by SMAC nor the final ensemble consume more than `memory_limit` megabytes.
29+
30+
Additionally, one can control the allocated time to search for a model via the argument `total_walltime_limit` to the `search()` method. This argument controls the total time SMAC can use to search for new configurations. The more time is allocated, the better the final estimator will be.
31+
32+
Ensemble Building Process
33+
=========================
34+
35+
Auto-PyTorch uses ensemble selection by `Caruana et al. (2004) <https://dl.acm.org/doi/pdf/10.1145/1015330.1015432>`_
36+
to build an ensemble based on the models’ prediction for the validation set. The following hyperparameters control how the ensemble is constructed:
37+
38+
* ``ensemble_size`` determines the maximal size of the ensemble. If it is set to zero, no ensemble will be constructed.
39+
* ``ensemble_nbest`` allows the user to directly specify the number of models considered for the ensemble. When an integer is provided for this hyperparameter, the final ensemble chooses each predictor from only the best n models. If a float between 0.0 and 1.0 is provided, ``ensemble_nbest`` would be interpreted as a fraction suggesting the percentage of models to use in the ensemble building process (namely, if ensemble_nbest is a float, library pruning is implemented as described in `Caruana et al. (2006) <https://dl.acm.org/doi/10.1109/ICDM.2006.76>`_). For example, if 10 candidates are available for the ensemble building process and the hyper-parameter is `ensemble_nbest==0.7``, we build an ensemble by taking the best 7 models among the original 10 candidate models.
40+
* ``max_models_on_disc`` defines the maximum number of models that are kept on the disc, as a mechanism to control the amount of disc space consumed by Auto-PyTorch. Throughout the automl process, different individual models are optimized, and their predictions (and other metadata) are stored on disc. The user can set the upper bound on how many models are acceptable to keep on disc, yet this variable takes priority in the definition of the number of models used by the ensemble builder (that is, the minimum of ``ensemble_size``, ``ensemble_nbest`` and ``max_models_on_disc`` determines the maximal amount of models used in the ensemble). If set to None, this feature is disabled.
41+
42+
Inspecting the results
43+
======================
44+
45+
Auto-PyTorch allows users to inspect the training results and statistics. The following example shows how different statistics can be printed for the inspection.
46+
47+
>>> from autoPyTorch.api.tabular_classification import TabularClassificationTask
48+
>>> automl = TabularClassificationTask()
49+
>>> automl.fit(X_train, y_train)
50+
>>> automl.show_models()

docs/releases.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@
1212
Releases
1313
========
1414

15-
Version 0.0.3
15+
Version 0.1.0
1616
==============
17-
TODO
17+
[refactor] Initial version of the new scikit-learn compatible API.
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""
2+
=======================
3+
Visualizing the Results
4+
=======================
5+
6+
Auto-Pytorch uses SMAC to fit individual machine learning algorithms
7+
and then ensembles them together using `Ensemble Selection
8+
<https://www.cs.cornell.edu/~caruana/ctp/ct.papers/caruana.icml04.icdm06long.pdf>`_.
9+
10+
The following examples shows how to visualize both the performance
11+
of the individual models and their respective ensemble.
12+
13+
Additionally, as we are compatible with scikit-learn,
14+
we show how to further interact with `Scikit-Learn Inspection
15+
<https://scikit-learn.org/stable/inspection.html>`_ support.
16+
17+
18+
"""
19+
import os
20+
import pickle
21+
import tempfile as tmp
22+
import time
23+
import warnings
24+
25+
# The following variables are not needed for every unix distribution, but are
26+
# highlighted in here to prevent problems with multiprocessing with scikit-learn.
27+
os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
28+
os.environ['OMP_NUM_THREADS'] = '1'
29+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
30+
os.environ['MKL_NUM_THREADS'] = '1'
31+
32+
warnings.simplefilter(action='ignore', category=UserWarning)
33+
warnings.simplefilter(action='ignore', category=FutureWarning)
34+
35+
import matplotlib.pyplot as plt
36+
37+
import numpy as np
38+
39+
import pandas as pd
40+
41+
42+
import sklearn.datasets
43+
import sklearn.model_selection
44+
from sklearn.inspection import permutation_importance
45+
46+
from smac.tae import StatusType
47+
48+
49+
from autoPyTorch.api.tabular_classification import TabularClassificationTask
50+
from autoPyTorch.metrics import accuracy
51+
52+
53+
if __name__ == '__main__':
54+
55+
############################################################################
56+
# Data Loading
57+
# ============
58+
59+
# We will use the iris dataset for this Toy example
60+
seed = 42
61+
X, y = sklearn.datasets.fetch_openml(data_id=61, return_X_y=True, as_frame=True)
62+
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
63+
X,
64+
y,
65+
random_state=42,
66+
)
67+
68+
############################################################################
69+
# Build and fit a classifier
70+
# ==========================
71+
api = TabularClassificationTask(seed=seed)
72+
api.search(
73+
X_train=X_train,
74+
y_train=y_train,
75+
X_test=X_test.copy(),
76+
y_test=y_test.copy(),
77+
optimize_metric=accuracy.name,
78+
total_walltime_limit=200,
79+
func_eval_time_limit_secs=50
80+
)
81+
82+
############################################################################
83+
# One can also save the model for future inference
84+
# ================================================
85+
86+
# For more details on how to deploy a model, please check
87+
# `Scikit-Learn persistence
88+
# <https://scikit-learn.org/stable/modules/model_persistence.html>`_ support.
89+
with open('estimator.pickle', 'wb') as handle:
90+
pickle.dump(api, handle, protocol=pickle.HIGHEST_PROTOCOL)
91+
92+
# Then let us read it back and use it for our analysis
93+
with open('estimator.pickle', 'rb') as handle:
94+
estimator = pickle.load(handle)
95+
96+
############################################################################
97+
# Plotting the model performance
98+
# ==============================
99+
100+
# We will plot the search incumbent through time.
101+
102+
# Collect the performance of individual machine learning algorithms
103+
# found by SMAC
104+
individual_performances = []
105+
for run_key, run_value in estimator.run_history.data.items():
106+
if run_value.status != StatusType.SUCCESS:
107+
# Ignore crashed runs
108+
continue
109+
individual_performances.append({
110+
'Timestamp': pd.Timestamp(
111+
time.strftime(
112+
'%Y-%m-%d %H:%M:%S',
113+
time.localtime(run_value.endtime)
114+
)
115+
),
116+
'single_best_optimization_accuracy': accuracy._optimum - run_value.cost,
117+
'single_best_test_accuracy': np.nan if run_value.additional_info is None else
118+
accuracy._optimum - run_value.additional_info['test_loss'],
119+
})
120+
individual_performance_frame = pd.DataFrame(individual_performances)
121+
122+
# Collect the performance of the ensemble through time
123+
# This ensemble is built from the machine learning algorithms
124+
# found by SMAC
125+
ensemble_performance_frame = pd.DataFrame(estimator.ensemble_performance_history)
126+
127+
# As we are tracking the incumbent, we are interested in the cummax() performance
128+
ensemble_performance_frame['ensemble_optimization_accuracy'] = ensemble_performance_frame[
129+
'train_accuracy'
130+
].cummax()
131+
ensemble_performance_frame['ensemble_test_accuracy'] = ensemble_performance_frame[
132+
'test_accuracy'
133+
].cummax()
134+
ensemble_performance_frame.drop(columns=['test_accuracy', 'train_accuracy'], inplace=True)
135+
individual_performance_frame['single_best_optimization_accuracy'] = individual_performance_frame[
136+
'single_best_optimization_accuracy'
137+
].cummax()
138+
individual_performance_frame['single_best_test_accuracy'] = individual_performance_frame[
139+
'single_best_test_accuracy'
140+
].cummax()
141+
142+
pd.merge(
143+
ensemble_performance_frame,
144+
individual_performance_frame,
145+
on="Timestamp", how='outer'
146+
).sort_values('Timestamp').fillna(method='ffill').plot(
147+
x='Timestamp',
148+
kind='line',
149+
legend=True,
150+
title='Auto-PyTorch accuracy over time',
151+
grid=True,
152+
)
153+
plt.show()
154+
155+
# We then can understand the importance of each input feature using
156+
# a permutation importance analysis. This is done as a proof of concept, to
157+
# showcase that we can leverage of scikit-learn API.
158+
result = permutation_importance(estimator, X_train, y_train, n_repeats=5,
159+
scoring='accuracy',
160+
random_state=seed)
161+
sorted_idx = result.importances_mean.argsort()
162+
163+
fig, ax = plt.subplots()
164+
ax.boxplot(result.importances[sorted_idx].T,
165+
vert=False, labels=X_test.columns[sorted_idx])
166+
ax.set_title("Permutation Importances (Train set)")
167+
fig.tight_layout()
168+
plt.show()

0 commit comments

Comments
 (0)