[ADD] Extra visualization example #189
Merged
nabenabe0928 merged 8 commits into automl:refactor_development from franchuterivera:refactor_development_extradocs on May 7, 2021
Changes from all commits (8 commits):
205e13b  [ADD] Extra visualization example  (franchuterivera)
d50ff32  Update docs/manual.rst  (franchuterivera)
fb370c7  Update docs/manual.rst  (franchuterivera)
e8aa61b  [Fix] missing version  (franchuterivera)
c2a0682  Merge branch 'refactor_development_extradocs' of github.com:franchute…  (franchuterivera)
a2df1bb  Update examples/tabular/40_advanced/example_visualization.py  (franchuterivera)
d6eeae0  [FIX] make docs more clear to the user  (franchuterivera)
6420849  Merge branch 'refactor_development_extradocs' of github.com:franchute…  (franchuterivera)
Version file:
@@ -1,4 +1,4 @@
 """Version information."""

 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.0.3"
+__version__ = "0.1.0"
examples/tabular/40_advanced/example_visualization.py (new file):
@@ -0,0 +1,168 @@
""" | ||
======================= | ||
Visualizing the Results | ||
======================= | ||
|
||
Auto-Pytorch uses SMAC to fit individual machine learning algorithms | ||
and then ensembles them together using `Ensemble Selection | ||
<https://www.cs.cornell.edu/~caruana/ctp/ct.papers/caruana.icml04.icdm06long.pdf>`_. | ||
|
||
The following examples shows how to visualize both the performance | ||
of the individual models and their respective ensemble. | ||
|
||
Additionally, as we are compatible with scikit-learn, | ||
we show how to further interact with `Scikit-Learn Inspection | ||
<https://scikit-learn.org/stable/inspection.html>`_ support. | ||
|
||
|
||
""" | ||
import os
import pickle
import tempfile as tmp
import time
import warnings

# These environment variables are not needed on every unix distribution, but
# are set here to prevent problems with multiprocessing in scikit-learn.
os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

[Review comment on this block: "What about those environment variables?"]

warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt

import numpy as np

import pandas as pd


import sklearn.datasets
import sklearn.model_selection
from sklearn.inspection import permutation_importance

from smac.tae import StatusType


from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.metrics import accuracy
if __name__ == '__main__':

    ############################################################################
    # Data Loading
    # ============

    # We will use the iris dataset for this toy example
    seed = 42
    X, y = sklearn.datasets.fetch_openml(data_id=61, return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X,
        y,
        random_state=42,
    )

    ############################################################################
    # Build and fit a classifier
    # ==========================
    api = TabularClassificationTask(seed=seed)
    api.search(
        X_train=X_train,
        y_train=y_train,
        X_test=X_test.copy(),
        y_test=y_test.copy(),
        optimize_metric=accuracy.name,
        total_walltime_limit=200,
        func_eval_time_limit_secs=50
    )
    ############################################################################
    # One can also save the model for future inference
    # ================================================

    # For more details on how to deploy a model, please check
    # `Scikit-Learn persistence
    # <https://scikit-learn.org/stable/modules/model_persistence.html>`_ support.
    with open('estimator.pickle', 'wb') as handle:
        pickle.dump(api, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Then let us read it back and use it for our analysis
    with open('estimator.pickle', 'rb') as handle:
        estimator = pickle.load(handle)
    ############################################################################
    # Plotting the model performance
    # ==============================

    # We will plot the search incumbent through time.

    # Collect the performance of individual machine learning algorithms
    # found by SMAC
    individual_performances = []
    for run_key, run_value in estimator.run_history.data.items():
        if run_value.status != StatusType.SUCCESS:
            # Ignore crashed runs
            continue
        individual_performances.append({
            'Timestamp': pd.Timestamp(
                time.strftime(
                    '%Y-%m-%d %H:%M:%S',
                    time.localtime(run_value.endtime)
                )
            ),
            'single_best_optimization_accuracy': accuracy._optimum - run_value.cost,
            'single_best_test_accuracy': np.nan if run_value.additional_info is None else
            accuracy._optimum - run_value.additional_info['test_loss'],
        })
    individual_performance_frame = pd.DataFrame(individual_performances)

    # Collect the performance of the ensemble through time
    # This ensemble is built from the machine learning algorithms
    # found by SMAC
    ensemble_performance_frame = pd.DataFrame(estimator.ensemble_performance_history)

    # As we are tracking the incumbent, we are interested in the cummax() performance
    ensemble_performance_frame['ensemble_optimization_accuracy'] = ensemble_performance_frame[
        'train_accuracy'
    ].cummax()
    ensemble_performance_frame['ensemble_test_accuracy'] = ensemble_performance_frame[
        'test_accuracy'
    ].cummax()
    ensemble_performance_frame.drop(columns=['test_accuracy', 'train_accuracy'], inplace=True)
    individual_performance_frame['single_best_optimization_accuracy'] = individual_performance_frame[
        'single_best_optimization_accuracy'
    ].cummax()
    individual_performance_frame['single_best_test_accuracy'] = individual_performance_frame[
        'single_best_test_accuracy'
    ].cummax()
    pd.merge(
        ensemble_performance_frame,
        individual_performance_frame,
        on="Timestamp", how='outer'
    ).sort_values('Timestamp').fillna(method='ffill').plot(
        x='Timestamp',
        kind='line',
        legend=True,
        title='Auto-PyTorch accuracy over time',
        grid=True,
    )
    plt.show()

    # We can then understand the importance of each input feature using
    # a permutation importance analysis. This is done as a proof of concept, to
    # showcase that we can leverage the scikit-learn API.
    result = permutation_importance(estimator, X_train, y_train, n_repeats=5,
                                    scoring='accuracy',
                                    random_state=seed)
    sorted_idx = result.importances_mean.argsort()

    fig, ax = plt.subplots()
    ax.boxplot(result.importances[sorted_idx].T,
               vert=False, labels=X_test.columns[sorted_idx])
    ax.set_title("Permutation Importances (Train set)")
    fig.tight_layout()
    plt.show()
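The example pickles the fitted estimator "for future inference" but never runs a prediction with the reloaded object. As a complement, here is a minimal sketch of that inference step, not part of this PR. It assumes the unpickled TabularClassificationTask exposes a scikit-learn-style predict() that returns labels in the original target space, which the permutation_importance call above already relies on, and it reuses the X_test/y_test split and the estimator.pickle file created by the example:

    # Minimal inference sketch (illustrative, not part of the PR): reload the
    # pickled estimator and score it on the held-out test split.
    import pickle

    import sklearn.metrics

    with open('estimator.pickle', 'rb') as handle:
        estimator = pickle.load(handle)

    # Assumes predict() returns labels comparable to y_test, as implied by the
    # permutation_importance usage in the example above.
    y_pred = estimator.predict(X_test)
    print("Test accuracy:", sklearn.metrics.accuracy_score(y_test, y_pred))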