Skip to content

Commit 4f73391

Browse files
sagar-kaushik authored and eddiebergman committed
Changes show_models() function to return a dictionary of models in ensemble (#1321)
* Changed show_models() function to return a dictionary of models in the ensemble instead of a string
1 parent 11119b8 commit 4f73391

16 files changed

+403
-32
lines changed

autosklearn/automl.py

Lines changed: 143 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,16 +1836,151 @@ def get_models_with_weights(self):
18361836

18371837
return self.ensemble_.get_models_with_weights(self.models_)
18381838

1839-
def show_models(self):
1840-
models_with_weights = self.get_models_with_weights()
1839+
def show_models(self) -> Dict[int, Any]:
1840+
""" Returns a dictionary containing dictionaries of ensemble models.
18411841
1842-
with io.StringIO() as sio:
1843-
sio.write("[")
1844-
for weight, model in models_with_weights:
1845-
sio.write("(%f, %s),\n" % (weight, model))
1846-
sio.write("]")
1842+
Each model in the ensemble can be accessed by giving its ``model_id`` as key.
18471843
1848-
return sio.getvalue()
1844+
A model dictionary contains the following:
1845+
1846+
* ``"model_id"`` - The id given to a model by ``autosklearn``.
1847+
* ``"rank"`` - The rank of the model based on its ``"cost"``.
1848+
* ``"cost"`` - The loss of the model on the validation set.
1849+
* ``"ensemble_weight"`` - The weight given to the model in the ensemble.
1850+
* ``"voting_model"`` - The ``cv_voting_ensemble`` model (for 'cv' resampling).
1851+
* ``"estimators"`` - List of models (dicts) in ``cv_voting_ensemble`` (for 'cv' resampling).
1852+
* ``"data_preprocessor"`` - The preprocessor used on the data.
1853+
* ``"balancing"`` - The balancing used on the data (for classification).
1854+
* ``"feature_preprocessor"`` - The preprocessor for feature types.
1855+
* ``"classifier"`` or ``"regressor"`` - The autosklearn wrapped classifier or regressor.
1856+
* ``"sklearn_classifier"`` or ``"sklearn_regressor"`` - The sklearn classifier or regressor.
1857+
1858+
**Example**
1859+
1860+
.. code-block:: python
1861+
1862+
import sklearn.datasets
1863+
import sklearn.metrics
1864+
import autosklearn.regression
1865+
1866+
X, y = sklearn.datasets.load_diabetes(return_X_y=True)
1867+
1868+
automl = autosklearn.regression.AutoSklearnRegressor(
1869+
time_left_for_this_task=120
1870+
)
1871+
automl.fit(X, y, dataset_name='diabetes')
1872+
1873+
ensemble_dict = automl.show_models()
1874+
print(ensemble_dict)
1875+
1876+
Output:
1877+
1878+
.. code-block:: text
1879+
1880+
{
1881+
25: {'model_id': 25,
1882+
'rank': 1,
1883+
'cost': 0.43667876507897496,
1884+
'ensemble_weight': 0.38,
1885+
'data_preprocessor': <autosklearn.pipeline.components.data_preprocessing....>,
1886+
'feature_preprocessor': <autosklearn.pipeline.components....>,
1887+
'regressor': <autosklearn.pipeline.components.regression....>,
1888+
'sklearn_regressor': SGDRegressor(alpha=0.0006517033225329654,...)
1889+
},
1890+
6: {'model_id': 6,
1891+
'rank': 2,
1892+
'cost': 0.4550418898836528,
1893+
'ensemble_weight': 0.3,
1894+
'data_preprocessor': <autosklearn.pipeline.components.data_preprocessing....>,
1895+
'feature_preprocessor': <autosklearn.pipeline.components....>,
1896+
'regressor': <autosklearn.pipeline.components.regression....>,
1897+
'sklearn_regressor': ARDRegression(alpha_1=0.0003701926442639788,...)
1898+
}...
1899+
}
1900+
1901+
Returns
1902+
-------
1903+
Dict[int, Any] : dictionary of length = number of models in the ensemble
1904+
A dictionary of models in the ensemble, where ``model_id`` is the key.
1905+
1906+
"""
1907+
1908+
ensemble_dict = {}
1909+
1910+
def has_key(rv, key):
1911+
return rv.additional_info and key in rv.additional_info
1912+
1913+
table_dict = {}
1914+
for rkey, rval in self.runhistory_.data.items():
1915+
if has_key(rval, 'num_run'):
1916+
model_id = rval.additional_info['num_run']
1917+
table_dict[model_id] = {
1918+
'model_id': model_id,
1919+
'cost': rval.cost
1920+
}
1921+
1922+
# Checking if the dictionary is empty
1923+
if not table_dict:
1924+
raise RuntimeError('No model found. Try increasing \'time_left_for_this_task\'.')
1925+
1926+
for i, weight in enumerate(self.ensemble_.weights_):
1927+
(_, model_id, _) = self.ensemble_.identifiers_[i]
1928+
table_dict[model_id]['ensemble_weight'] = weight
1929+
1930+
table = pd.DataFrame.from_dict(table_dict, orient='index')
1931+
table.sort_values(by='cost', inplace=True)
1932+
1933+
# Checking which resampling strategy is chosen and selecting the appropriate models
1934+
is_cv = (self._resampling_strategy == "cv")
1935+
models = self.cv_models_ if is_cv else self.models_
1936+
1937+
rank = 1 # Initializing rank for the first model
1938+
for (_, model_id, _), model in models.items():
1939+
model_dict = {} # Declaring model dictionary
1940+
1941+
# Inserting model_id, rank, cost and ensemble weight
1942+
model_dict['model_id'] = table.loc[model_id]['model_id'].astype(int)
1943+
model_dict['rank'] = rank
1944+
model_dict['cost'] = table.loc[model_id]['cost']
1945+
model_dict['ensemble_weight'] = table.loc[model_id]['ensemble_weight']
1946+
rank += 1 # Incrementing rank by 1 for the next model
1947+
1948+
# The steps in the models pipeline are as follows:
1949+
# 'data_preprocessor': DataPreprocessor,
1950+
# 'balancing': Balancing,
1951+
# 'feature_preprocessor': FeaturePreprocessorChoice,
1952+
# 'classifier'/'regressor': ClassifierChoice/RegressorChoice (autosklearn wrapped model)
1953+
1954+
# For 'cv' (cross validation) strategy
1955+
if is_cv:
1956+
# Voting model created by cross validation
1957+
cv_voting_ensemble = model
1958+
model_dict['voting_model'] = cv_voting_ensemble
1959+
1960+
# List of models, each trained on one cv fold
1961+
cv_models = []
1962+
for cv_model in cv_voting_ensemble.estimators_:
1963+
estimator = dict(cv_model.steps)
1964+
1965+
# Adding sklearn model to the model dictionary
1966+
model_type, autosklearn_wrapped_model = cv_model.steps[-1]
1967+
estimator[f'sklearn_{model_type}'] = autosklearn_wrapped_model.choice.estimator
1968+
cv_models.append(estimator)
1969+
model_dict['estimators'] = cv_models
1970+
1971+
# For any other strategy
1972+
else:
1973+
steps = dict(model.steps)
1974+
model_dict.update(steps)
1975+
1976+
# Adding sklearn model to the model dictionary
1977+
model_type, autosklearn_wrapped_model = model.steps[-1]
1978+
model_dict[f'sklearn_{model_type}'] = autosklearn_wrapped_model.choice.estimator
1979+
1980+
# Inserting model_dict in the ensemble dictionary
1981+
ensemble_dict[model_id] = model_dict
1982+
1983+
return ensemble_dict
18491984

18501985
def _create_search_space(
18511986
self,

autosklearn/estimators.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,13 +537,74 @@ def score(self, X, y):
537537
return self.automl_.score(X, y)
538538

539539
def show_models(self):
540-
"""Return a representation of the final ensemble found by auto-sklearn.
540+
""" Returns a dictionary containing dictionaries of ensemble models.
541+
542+
Each model in the ensemble can be accessed by giving its ``model_id`` as key.
543+
544+
A model dictionary contains the following:
545+
546+
* ``"model_id"`` - The id given to a model by ``autosklearn``.
547+
* ``"rank"`` - The rank of the model based on its ``"cost"``.
548+
* ``"cost"`` - The loss of the model on the validation set.
549+
* ``"ensemble_weight"`` - The weight given to the model in the ensemble.
550+
* ``"voting_model"`` - The ``cv_voting_ensemble`` model (for 'cv' resampling).
551+
* ``"estimators"`` - List of models (dicts) in ``cv_voting_ensemble`` (for 'cv' resampling).
552+
* ``"data_preprocessor"`` - The preprocessor used on the data.
553+
* ``"balancing"`` - The balancing used on the data (for classification).
554+
* ``"feature_preprocessor"`` - The preprocessor for feature types.
555+
* ``"classifier"`` or ``"regressor"`` - The autosklearn wrapped classifier or regressor.
556+
* ``"sklearn_classifier"`` or ``"sklearn_regressor"`` - The sklearn classifier or regressor.
557+
558+
**Example**
559+
560+
.. code-block:: python
561+
562+
import sklearn.datasets
563+
import sklearn.metrics
564+
import autosklearn.regression
565+
566+
X, y = sklearn.datasets.load_diabetes(return_X_y=True)
567+
568+
automl = autosklearn.regression.AutoSklearnRegressor(
569+
time_left_for_this_task=120
570+
)
571+
automl.fit(X, y, dataset_name='diabetes')
572+
573+
ensemble_dict = automl.show_models()
574+
print(ensemble_dict)
575+
576+
Output:
577+
578+
.. code-block:: text
579+
580+
{
581+
25: {'model_id': 25,
582+
'rank': 1,
583+
'cost': 0.43667876507897496,
584+
'ensemble_weight': 0.38,
585+
'data_preprocessor': <autosklearn.pipeline.components.data_preprocessing....>,
586+
'feature_preprocessor': <autosklearn.pipeline.components....>,
587+
'regressor': <autosklearn.pipeline.components.regression....>,
588+
'sklearn_regressor': SGDRegressor(alpha=0.0006517033225329654,...)
589+
},
590+
6: {'model_id': 6,
591+
'rank': 2,
592+
'cost': 0.4550418898836528,
593+
'ensemble_weight': 0.3,
594+
'data_preprocessor': <autosklearn.pipeline.components.data_preprocessing....>,
595+
'feature_preprocessor': <autosklearn.pipeline.components....>,
596+
'regressor': <autosklearn.pipeline.components.regression....>,
597+
'sklearn_regressor': ARDRegression(alpha_1=0.0003701926442639788,...)
598+
}...
599+
}
541600
542601
Returns
543602
-------
544-
str
603+
Dict[int, Any] : dictionary of length = number of models in the ensemble
604+
A dictionary of models in the ensemble, where ``model_id`` is the key.
545605
546606
"""
607+
547608
return self.automl_.show_models()
548609

549610
def get_models_with_weights(self):

examples/20_basic/example_classification.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
The following example shows how to fit a simple classification model with
88
*auto-sklearn*.
99
"""
10+
from pprint import pprint
11+
1012
import sklearn.datasets
1113
import sklearn.metrics
1214

@@ -42,7 +44,7 @@
4244
# Print the final ensemble constructed by auto-sklearn
4345
# ====================================================
4446

45-
print(automl.show_models())
47+
pprint(automl.show_models(), indent=4)
4648

4749
###########################################################################
4850
# Get the Score of the final ensemble

examples/20_basic/example_multilabel_classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
`here <https://scikit-learn.org/stable/modules/multiclass.html>`_.
99
"""
1010
import numpy as np
11+
from pprint import pprint
1112

1213
import sklearn.datasets
1314
import sklearn.metrics
@@ -65,7 +66,7 @@
6566
# Print the final ensemble constructed by auto-sklearn
6667
# ====================================================
6768

68-
print(automl.show_models())
69+
pprint(automl.show_models(), indent=4)
6970

7071
############################################################################
7172
# Print statistics about the auto-sklearn run

examples/20_basic/example_multioutput_regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
*auto-sklearn*.
99
"""
1010
import numpy as numpy
11+
from pprint import pprint
1112

1213
from sklearn.datasets import make_regression
1314
from sklearn.metrics import r2_score
@@ -46,7 +47,7 @@
4647
# Print the final ensemble constructed by auto-sklearn
4748
# ====================================================
4849

49-
print(automl.show_models())
50+
pprint(automl.show_models(), indent=4)
5051

5152
###########################################################################
5253
# Get the Score of the final ensemble

examples/20_basic/example_regression.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
The following example shows how to fit a simple regression model with
88
*auto-sklearn*.
99
"""
10+
from pprint import pprint
11+
1012
import sklearn.datasets
1113
import sklearn.metrics
1214

@@ -43,7 +45,7 @@
4345
# Print the final ensemble constructed by auto-sklearn
4446
# ====================================================
4547

46-
print(automl.show_models())
48+
pprint(automl.show_models(), indent=4)
4749

4850
#####################################
4951
# Get the Score of the final ensemble

examples/40_advanced/example_get_pipeline_components.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
the sklearn models. This example illustrates how to interact
1515
with the sklearn components directly, in this case a PCA preprocessor.
1616
"""
17+
from pprint import pprint
18+
1719
import sklearn.datasets
1820
import sklearn.metrics
1921

@@ -62,10 +64,17 @@
6264
# `Ensemble Selection <https://www.cs.cornell.edu/~alexn/papers/shotgun.icml04.revised.rev2.pdf>`_
6365
# to construct ensembles in a post-hoc fashion. The ensemble is a linear
6466
# weighting of all models constructed during the hyperparameter optimization.
65-
# This prints the final ensemble. It is a list of tuples, each tuple being
66-
# the model weight in the ensemble and the model itself.
67-
68-
print(automl.show_models())
67+
# This prints the final ensemble. It is a dictionary where ``model_id`` of
68+
# each model is a key, and value is a dictionary containing information
69+
# of that model. A model's dict contains its ``'model_id'``, ``'rank'``,
70+
# ``'cost'``, ``'ensemble_weight'``, and the model itself. The model is
71+
# given by the ``'data_preprocessor'``, ``'feature_preprocessor'``,
72+
# ``'regressor'/'classifier'`` and ``'sklearn_regressor'/'sklearn_classifier'``
73+
# entries. But for the ``'cv'`` resampling strategy, the same for each cv
74+
# model is stored in the ``'estimators'`` list in the dict, along with the
75+
# ``'voting_model'``.
76+
77+
pprint(automl.show_models(), indent=4)
6978

7079
###########################################################################
7180
# Report statistics about the search

examples/40_advanced/example_interpretable_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
The following example shows how to inspect the models which *auto-sklearn*
88
optimizes over and how to restrict them to an interpretable subset.
99
"""
10+
from pprint import pprint
11+
1012
import autosklearn.classification
1113
import sklearn.datasets
1214
import sklearn.metrics
@@ -70,7 +72,7 @@
7072
# Print the final ensemble constructed by auto-sklearn
7173
# ====================================================
7274

73-
print(automl.show_models())
75+
pprint(automl.show_models(), indent=4)
7476

7577
###########################################################################
7678
# Get the Score of the final ensemble

examples/60_search/example_random_search.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
as yet another alternative optimization strategy.
1313
Both examples are intended to show how the optimization strategy in *auto-sklearn* can be adapted.
1414
""" # noqa (links are too long)
15+
from pprint import pprint
1516

1617
import sklearn.model_selection
1718
import sklearn.datasets
@@ -75,7 +76,7 @@ def get_roar_object_callback(
7576
print('#' * 80)
7677
print('Results for ROAR.')
7778
# Print the final ensemble constructed by auto-sklearn via ROAR.
78-
print(automl.show_models())
79+
pprint(automl.show_models(), indent=4)
7980
predictions = automl.predict(X_test)
8081
# Print statistics about the auto-sklearn run such as number of
8182
# iterations, number of models failed with a time out.
@@ -129,7 +130,7 @@ def get_random_search_object_callback(
129130
print('Results for random search.')
130131

131132
# Print the final ensemble constructed by auto-sklearn via random search.
132-
print(automl.show_models())
133+
pprint(automl.show_models(), indent=4)
133134

134135
# Print statistics about the auto-sklearn run such as number of
135136
# iterations, number of models failed with a time out.

examples/60_search/example_sequential.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
sequentially. The example below shows how to first fit the models and build the
99
ensembles afterwards.
1010
"""
11+
from pprint import pprint
1112

1213
import sklearn.model_selection
1314
import sklearn.datasets
@@ -48,7 +49,7 @@
4849
# Print the final ensemble constructed by auto-sklearn
4950
# ====================================================
5051

51-
print(automl.show_models())
52+
pprint(automl.show_models(), indent=4)
5253

5354
############################################################################
5455
# Get the Score of the final ensemble

0 commit comments

Comments
 (0)