|
13 | 13 | import time
|
14 | 14 | import unittest.mock
|
15 | 15 | import uuid
|
| 16 | +import warnings |
16 | 17 |
|
17 | 18 | import dask
|
18 | 19 | import dask.distributed
|
@@ -302,6 +303,7 @@ def __init__(
|
302 | 303 |
|
303 | 304 | # The ensemble performance history through time
|
304 | 305 | self.ensemble_performance_history = []
|
| 306 | + self.fitted = False |
305 | 307 |
|
306 | 308 | # Single core, local runs should use fork
|
307 | 309 | # to prevent the __main__ requirements in
|
@@ -348,7 +350,7 @@ def _create_dask_client(self):
|
348 | 350 | processes=False,
|
349 | 351 | threads_per_worker=1,
|
350 | 352 | # We use the temporal directory to save the
|
351 |
| - # dask workers, because deleting workers |
| 353 | + # dask workers, because deleting workers takes |
352 | 354 | # more time than deleting backend directories
|
353 | 355 | # This prevent an error saying that the worker
|
354 | 356 | # file was deleted, so the client could not close
|
@@ -562,7 +564,7 @@ def fit(
|
562 | 564 | # "multiclass" be mean either REGRESSION or MULTICLASS_CLASSIFICATION,
|
563 | 565 | # and so this is where the subclasses are used to determine which.
|
564 | 566 | # However, this could also be deduced from the `is_classification`
|
565 |
| - # paramaeter. |
| 567 | + # parameter. |
566 | 568 | #
|
567 | 569 | # In the future, there is little need for the subclasses of `AutoML`
|
568 | 570 | # and no need for the `task` parameter. The extra functionality
|
@@ -1068,9 +1070,13 @@ def fit(
|
1068 | 1070 | self._logger.info("Finished loading models...")
|
1069 | 1071 |
|
1070 | 1072 | self._fit_cleanup()
|
| 1073 | + self.fitted = True |
1071 | 1074 |
|
1072 | 1075 | return self
|
1073 | 1076 |
|
| 1077 | + def __sklearn_is_fitted__(self) -> bool: |
| 1078 | + return self.fitted |
| 1079 | + |
1074 | 1080 | def _fit_cleanup(self):
|
1075 | 1081 | self._logger.info("Closing the dask infrastructure")
|
1076 | 1082 | self._close_dask_client()
|
@@ -1481,6 +1487,10 @@ def fit_ensemble(
|
1481 | 1487 | ensemble_nbest=None,
|
1482 | 1488 | ensemble_size=None,
|
1483 | 1489 | ):
|
| 1490 | +        # reject non-positive ensemble_size values (must be > 0) |
| 1491 | + if not ensemble_size > 0: |
| 1492 | + raise ValueError("ensemble_size must be greater than 0 for fit_ensemble") |
| 1493 | + |
1484 | 1494 | # AutoSklearn does not handle sparse y for now
|
1485 | 1495 | y = convert_if_sparse(y)
|
1486 | 1496 |
|
@@ -1971,9 +1981,21 @@ def show_models(self) -> Dict[int, Any]:
|
1971 | 1981 | -------
|
1972 | 1982 | Dict(int, Any) : dictionary of length = number of models in the ensemble
|
1973 | 1983 | A dictionary of models in the ensemble, where ``model_id`` is the key.
|
1974 |
| -
|
1975 | 1984 | """ # noqa: E501
|
1976 | 1985 | ensemble_dict = {}
|
| 1986 | +        # raise a RuntimeError if AutoSklearn has not been fitted yet |
| 1987 | + if not self.__sklearn_is_fitted__(): |
| 1988 | + raise RuntimeError("AutoSklearn has not been fitted") |
| 1989 | + |
| 1990 | + # check for ensemble_size == 0 |
| 1991 | + if self._ensemble_size == 0: |
| 1992 | + warnings.warn("No models in the ensemble. Kindly check the ensemble size.") |
| 1993 | + return ensemble_dict |
| 1994 | + |
| 1995 | +        # handle the case when ensemble_size > 0 but there is no ensemble to load |
| 1996 | + if self.ensemble_ is None: |
| 1997 | + warnings.warn("No ensemble found. Returning empty dictionary.") |
| 1998 | + return ensemble_dict |
1977 | 1999 |
|
1978 | 2000 | def has_key(rv, key):
|
1979 | 2001 | return rv.additional_info and key in rv.additional_info
|
|
0 commit comments