Skip to content

Commit b90d61d

Browse files
Fix the ensemble_size == 0 error in automl.py (#1369)
* Fix the ensemble_size == 0 error in the fit_ensemble and show_models functions by raising a ValueError in the former and by issuing a warning and returning an empty dictionary in the latter * Update automl.py * Two tests for ensemble_size == 0 cases Added two tests to check that automl.fit_ensemble() raises an error when ensemble_size == 0 and that show_models() returns an empty dictionary when ensemble_size == 0 * Update automl.py * Update test_automl.py Test for checking that the show_models() function raises an error if the models are not fitted. * Update automl.py Add a function __sklearn_is_fitted__() which returns the boolean value of self.fitted, and add the check for model fitting to the show_models() function. * Update autosklearn/automl.py * Formatting changes to pass all the pre-commit checks Co-authored-by: Eddie Bergman <eddiebergmanhs@gmail.com>
1 parent 3a1d8f2 commit b90d61d

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

autosklearn/automl.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import time
1414
import unittest.mock
1515
import uuid
16+
import warnings
1617

1718
import dask
1819
import dask.distributed
@@ -302,6 +303,7 @@ def __init__(
302303

303304
# The ensemble performance history through time
304305
self.ensemble_performance_history = []
306+
self.fitted = False
305307

306308
# Single core, local runs should use fork
307309
# to prevent the __main__ requirements in
@@ -348,7 +350,7 @@ def _create_dask_client(self):
348350
processes=False,
349351
threads_per_worker=1,
350352
# We use the temporary directory to save the
351-
# dask workers, because deleting workers
353+
# dask workers, because deleting workers takes
352354
# more time than deleting backend directories
353355
# This prevents an error saying that the worker
354356
# file was deleted, so the client could not close
@@ -562,7 +564,7 @@ def fit(
562564
# "multiclass" be mean either REGRESSION or MULTICLASS_CLASSIFICATION,
563565
# and so this is where the subclasses are used to determine which.
564566
# However, this could also be deduced from the `is_classification`
565-
# paramaeter.
567+
# parameter.
566568
#
567569
# In the future, there is little need for the subclasses of `AutoML`
568570
# and no need for the `task` parameter. The extra functionality
@@ -1068,9 +1070,13 @@ def fit(
10681070
self._logger.info("Finished loading models...")
10691071

10701072
self._fit_cleanup()
1073+
self.fitted = True
10711074

10721075
return self
10731076

1077+
def __sklearn_is_fitted__(self) -> bool:
1078+
return self.fitted
1079+
10741080
def _fit_cleanup(self):
10751081
self._logger.info("Closing the dask infrastructure")
10761082
self._close_dask_client()
@@ -1481,6 +1487,10 @@ def fit_ensemble(
14811487
ensemble_nbest=None,
14821488
ensemble_size=None,
14831489
):
1490+
# check for the case when ensemble_size is not greater than 0 (i.e. zero or negative)
1491+
if not ensemble_size > 0:
1492+
raise ValueError("ensemble_size must be greater than 0 for fit_ensemble")
1493+
14841494
# AutoSklearn does not handle sparse y for now
14851495
y = convert_if_sparse(y)
14861496

@@ -1971,9 +1981,21 @@ def show_models(self) -> Dict[int, Any]:
19711981
-------
19721982
Dict(int, Any) : dictionary of length = number of models in the ensemble
19731983
A dictionary of models in the ensemble, where ``model_id`` is the key.
1974-
19751984
""" # noqa: E501
19761985
ensemble_dict = {}
1986+
# check whether autosklearn has been fitted; if not, raise a RuntimeError
1987+
if not self.__sklearn_is_fitted__():
1988+
raise RuntimeError("AutoSklearn has not been fitted")
1989+
1990+
# check for ensemble_size == 0
1991+
if self._ensemble_size == 0:
1992+
warnings.warn("No models in the ensemble. Kindly check the ensemble size.")
1993+
return ensemble_dict
1994+
1995+
# check for condition when ensemble_size > 0 but there is no ensemble to load
1996+
if self.ensemble_ is None:
1997+
warnings.warn("No ensemble found. Returning empty dictionary.")
1998+
return ensemble_dict
19771999

19782000
def has_key(rv, key):
19792001
return rv.additional_info and key in rv.additional_info

test/test_automl/test_automl.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,48 @@ def test_fit(dask_client):
8888
del automl
8989

9090

91+
def test_ensemble_size_zero():
92+
"""Test if automl.fit_ensemble raises error when ensemble_size == 0"""
93+
X_train, Y_train, X_test, Y_test = putil.get_dataset("iris")
94+
automl = autosklearn.automl.AutoML(
95+
seed=0,
96+
time_left_for_this_task=30,
97+
per_run_time_limit=5,
98+
metric=accuracy,
99+
ensemble_size=0,
100+
)
101+
automl.fit(X_train, Y_train, task=MULTICLASS_CLASSIFICATION)
102+
with pytest.raises(ValueError):
103+
automl.fit_ensemble(Y_test, ensemble_size=0)
104+
105+
106+
def test_empty_dict_in_show_models():
107+
"""Test if show_models() returns empty dictionary when ensemble_size == 0"""
108+
X_train, Y_train, X_test, Y_test = putil.get_dataset("iris")
109+
automl = autosklearn.automl.AutoMLClassifier(
110+
seed=0,
111+
time_left_for_this_task=30,
112+
per_run_time_limit=5,
113+
metric=accuracy,
114+
ensemble_size=0,
115+
)
116+
automl.fit(X_train, Y_train)
117+
assert automl.show_models() == {}
118+
119+
120+
def test_fitted_models_in_show_models():
121+
X_train, Y_train, X_test, Y_test = putil.get_dataset("iris")
122+
automl = autosklearn.automl.AutoMLClassifier(
123+
seed=0,
124+
time_left_for_this_task=30,
125+
per_run_time_limit=5,
126+
metric=accuracy,
127+
ensemble_size=0,
128+
)
129+
with pytest.raises(RuntimeError, match="AutoSklearn has not been fitted"):
130+
automl.show_models()
131+
132+
91133
def test_fit_roar(dask_client_single_worker):
92134
def get_roar_object_callback(
93135
scenario_dict, seed, ta, ta_kwargs, dask_client, n_jobs, **kwargs

0 commit comments

Comments
 (0)