Enabled pipeline fit #1096

Merged 5 commits on Apr 13, 2021
294 changes: 268 additions & 26 deletions autosklearn/automl.py

Large diffs are not rendered by default.

179 changes: 128 additions & 51 deletions autosklearn/estimators.py
@@ -1,13 +1,20 @@
# -*- encoding: utf-8 -*-

from typing import Optional, Dict, List
from typing import Optional, Dict, List, Tuple, Union

from ConfigSpace.configuration_space import Configuration
import dask.distributed
import joblib
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.multiclass import type_of_target
from smac.runhistory.runhistory import RunInfo, RunValue

from autosklearn.data.validation import (
SUPPORTED_FEAT_TYPES,
SUPPORTED_TARGET_TYPES,
)
from autosklearn.pipeline.base import BasePipeline
from autosklearn.automl import AutoMLClassifier, AutoMLRegressor, AutoML
from autosklearn.metrics import Scorer
from autosklearn.util.backend import create
@@ -271,8 +278,15 @@ def __init__(
self.load_models = load_models

self.automl_ = None # type: Optional[AutoML]
# n_jobs after conversion to a number (b/c default is None)

# Handle the number of jobs and the time for them
self._n_jobs = None
if self.n_jobs is None or self.n_jobs == 1:
self._n_jobs = 1
elif self.n_jobs == -1:
self._n_jobs = joblib.cpu_count()
else:
self._n_jobs = self.n_jobs

super().__init__()

@@ -281,35 +295,24 @@ def __getstate__(self):
self.dask_client = None
return self.__dict__

def build_automl(
self,
seed: int,
ensemble_size: int,
initial_configurations_via_metalearning: int,
tmp_folder: str,
output_folder: str,
smac_scenario_args: Optional[Dict] = None,
):
def build_automl(self):

backend = create(
temporary_directory=tmp_folder,
output_directory=output_folder,
temporary_directory=self.tmp_folder,
output_directory=self.output_folder,
delete_tmp_folder_after_terminate=self.delete_tmp_folder_after_terminate,
delete_output_folder_after_terminate=self.delete_output_folder_after_terminate,
)

if smac_scenario_args is None:
smac_scenario_args = self.smac_scenario_args

automl = self._get_automl_class()(
backend=backend,
time_left_for_this_task=self.time_left_for_this_task,
per_run_time_limit=self.per_run_time_limit,
initial_configurations_via_metalearning=initial_configurations_via_metalearning,
ensemble_size=ensemble_size,
initial_configurations_via_metalearning=self.initial_configurations_via_metalearning,
ensemble_size=self.ensemble_size,
ensemble_nbest=self.ensemble_nbest,
max_models_on_disc=self.max_models_on_disc,
seed=seed,
seed=self.seed,
memory_limit=self.memory_limit,
include_estimators=self.include_estimators,
exclude_estimators=self.exclude_estimators,
@@ -321,7 +324,7 @@ def build_automl(
dask_client=self.dask_client,
get_smac_object_callback=self.get_smac_object_callback,
disable_evaluator_output=self.disable_evaluator_output,
smac_scenario_args=smac_scenario_args,
smac_scenario_args=self.smac_scenario_args,
logging_config=self.logging_config,
metadata_directory=self.metadata_directory,
metric=self.metric,
@@ -332,32 +335,82 @@

def fit(self, **kwargs):

# Handle the number of jobs and the time for them
if self.n_jobs is None or self.n_jobs == 1:
self._n_jobs = 1
elif self.n_jobs == -1:
self._n_jobs = joblib.cpu_count()
else:
self._n_jobs = self.n_jobs

# Automatically set the cutoff time per task
if self.per_run_time_limit is None:
self.per_run_time_limit = self._n_jobs * self.time_left_for_this_task // 10

seed = self.seed
self.automl_ = self.build_automl(
seed=seed,
ensemble_size=self.ensemble_size,
initial_configurations_via_metalearning=(
self.initial_configurations_via_metalearning
),
tmp_folder=self.tmp_folder,
output_folder=self.output_folder,
)
if self.automl_ is None:
self.automl_ = self.build_automl()
self.automl_.fit(load_models=self.load_models, **kwargs)

return self
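As a worked sketch of the defaulting above (the budget numbers are illustrative assumptions, not values from the PR): with four jobs and a 360-second overall budget, the automatic per-run cutoff comes out to 4 * 360 // 10 = 144 seconds.

```python
import joblib

def resolve_n_jobs(n_jobs):
    # Mirrors the __init__ logic above: None or 1 means a single worker,
    # -1 expands to every available core via joblib.
    if n_jobs is None or n_jobs == 1:
        return 1
    if n_jobs == -1:
        return joblib.cpu_count()
    return n_jobs

# Illustrative budget: 4 jobs, 360 seconds for the whole search.
time_left_for_this_task = 360
per_run_time_limit = resolve_n_jobs(4) * time_left_for_this_task // 10
assert per_run_time_limit == 144
```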

def fit_pipeline(
self,
X: SUPPORTED_FEAT_TYPES,
y: SUPPORTED_TARGET_TYPES,
config: Union[Configuration, Dict[str, Union[str, float, int]]],
dataset_name: Optional[str] = None,
X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
y_test: Optional[SUPPORTED_TARGET_TYPES] = None,
feat_type: Optional[List[str]] = None,
*args,
**kwargs: Dict,
) -> Tuple[Optional[BasePipeline], RunInfo, RunValue]:
""" Fits and individual pipeline configuration and returns
the result to the user.

The Estimator constraints are honored, for example the resampling
strategy or memory constraints, unless directly overridden by the
arguments to this method. By default, this method supports the same
signature as fit(), and any extra arguments are redirected to the TAE
evaluation function, which allows for further customization while
building a pipeline. In other words, any additional argument provided
is passed directly to the worker exercising the run.

Parameters
----------
X: array-like, shape = (n_samples, n_features)
The features used for training
y: array-like
The labels used for training
X_test: Optional[array-like], shape = (n_samples, n_features)
If provided, the testing performance will be tracked on these features.
y_test: Optional[array-like]
If provided, the testing performance will be tracked on these labels.
config: Union[Configuration, Dict[str, Union[str, float, int]]]
A configuration object used to define the pipeline steps.
If a dictionary is passed, a configuration is created based on this dictionary.
dataset_name: Optional[str]
Name that will be used to tag and identify the Auto-Sklearn run.
feat_type : list, optional (default=None)
List of str of length `X.shape[1]` describing the attribute type.
Possible types are `Categorical` and `Numerical`. `Categorical`
attributes will be automatically One-Hot encoded. The values
used for a categorical attribute must be integers, obtained for
example by `sklearn.preprocessing.LabelEncoder
<http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html>`_.

Returns
-------
pipeline: Optional[BasePipeline]
The fitted pipeline. In case of failure while fitting the pipeline,
None is returned.
run_info: RunInfo
A named tuple that contains the launched configuration
run_value: RunValue
A named tuple that contains the result of the run
"""
if self.automl_ is None:
self.automl_ = self.build_automl()
return self.automl_.fit_pipeline(X=X, y=y,
dataset_name=dataset_name,
config=config,
feat_type=feat_type,
X_test=X_test, y_test=y_test,
*args, **kwargs)
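A minimal usage sketch of the new fit_pipeline API, assuming the iris dataset and a configuration sampled from the estimator's own space (both the dataset and the sampling step are illustrative choices, not taken from this diff):

```python
import sklearn.datasets
from autosklearn.classification import AutoSklearnClassifier

X, y = sklearn.datasets.load_iris(return_X_y=True)
cls = AutoSklearnClassifier(time_left_for_this_task=120)

# Sample one configuration from the space auto-sklearn would search.
space = cls.get_configuration_space(X, y, dataset_name="iris")
config = space.sample_configuration()

# Fit just that configuration, honoring the estimator's constraints.
pipeline, run_info, run_value = cls.fit_pipeline(
    X=X, y=y, config=config, dataset_name="iris",
)
if pipeline is not None:  # None signals a failed fit
    print(run_value.cost)
```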

def fit_ensemble(self, y, task=None, precision=32,
dataset_name=None, ensemble_nbest=None,
ensemble_size=None):
@@ -401,17 +454,9 @@ def fit_ensemble(self, y, task=None, precision=32,
"""
if self.automl_ is None:
# Build a dummy automl object to call fit_ensemble
self.automl_ = self.build_automl(
seed=self.seed,
ensemble_size=(
ensemble_size
if ensemble_size is not None else
self.ensemble_size
),
initial_configurations_via_metalearning=0,
tmp_folder=self.tmp_folder,
output_folder=self.output_folder,
)
# The ensemble size is honored in the .automl_.fit_ensemble
# call
self.automl_ = self.build_automl()
self.automl_.fit_ensemble(
y=y,
task=task,
@@ -513,8 +558,40 @@ def sprint_statistics(self):
def _get_automl_class(self):
raise NotImplementedError()

def get_configuration_space(self, X, y):
return self.automl_.configuration_space
def get_configuration_space(
self,
X: SUPPORTED_FEAT_TYPES,
y: SUPPORTED_TARGET_TYPES,
X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
y_test: Optional[SUPPORTED_TARGET_TYPES] = None,
dataset_name: Optional[str] = None,
):
"""
Returns the Configuration Space object, from which Auto-Sklearn
will sample configurations and build pipelines.

Parameters
----------
X : array-like or sparse matrix of shape = [n_samples, n_features]
Array with the training features, used to get characteristics like
data sparsity
y : array-like, shape = [n_samples] or [n_samples, n_outputs]
Array with the problem labels
X_test : array-like or sparse matrix of shape = [n_samples, n_features]
Array with features used for performance estimation
y_test : array-like, shape = [n_samples] or [n_samples, n_outputs]
Array with the problem labels for the testing split
dataset_name: Optional[str]
A string to tag the Auto-Sklearn run
"""
if self.automl_ is None:
self.automl_ = self.build_automl()
return self.automl_.fit(
X, y,
X_test=X_test, y_test=y_test,
dataset_name=dataset_name,
only_return_configuration_space=True,
) if self.automl_.configuration_space is None else self.automl_.configuration_space
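A short sketch of the lazy construction above; the dataset and the inspected hyperparameter name are assumptions for illustration:

```python
import sklearn.datasets
from autosklearn.classification import AutoSklearnClassifier

X, y = sklearn.datasets.load_iris(return_X_y=True)
cls = AutoSklearnClassifier(time_left_for_this_task=60)

# The first call builds the AutoML object and derives the space from the
# data (only_return_configuration_space=True skips the actual search);
# later calls return the already-built space.
cs = cls.get_configuration_space(X, y, dataset_name="iris")
print(cs.get_hyperparameter("classifier:__choice__"))
```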


class AutoSklearnClassifier(AutoSklearnEstimator, ClassifierMixin):
38 changes: 9 additions & 29 deletions autosklearn/smbo.py
@@ -23,14 +23,14 @@
import autosklearn.metalearning
from autosklearn.constants import MULTILABEL_CLASSIFICATION, \
BINARY_CLASSIFICATION, TASK_TYPES_TO_STRING, CLASSIFICATION_TASKS, \
REGRESSION_TASKS, MULTICLASS_CLASSIFICATION, REGRESSION, \
MULTIOUTPUT_REGRESSION
MULTICLASS_CLASSIFICATION, REGRESSION, MULTIOUTPUT_REGRESSION
from autosklearn.ensemble_builder import EnsembleBuilderManager
from autosklearn.metalearning.mismbo import suggest_via_metalearning
from autosklearn.data.abstract_data_manager import AbstractDataManager
from autosklearn.evaluation import ExecuteTaFuncWithQueue, get_cost_of_crash
from autosklearn.util.logging_ import get_named_client_logger
from autosklearn.util.parallel import preload_modules
from autosklearn.util.pipeline import parse_include_exclude_components
from autosklearn.metalearning.metalearning.meta_base import MetaBase
from autosklearn.metalearning.metafeatures.metafeatures import \
calculate_all_metafeatures_with_labels, calculate_all_metafeatures_encoded_labels
@@ -416,33 +416,13 @@ def run_smbo(self):
# evaluator, which takes into account that a run can be killed prior
# to the model being fully fitted; thus putting intermediate results
# into a queue and querying them once the time is over
exclude = dict()
include = dict()
if self.include_preprocessors is not None and self.exclude_preprocessors is not None:
raise ValueError('Cannot specify include_preprocessors and '
'exclude_preprocessors.')
elif self.include_preprocessors is not None:
include['feature_preprocessor'] = self.include_preprocessors
elif self.exclude_preprocessors is not None:
exclude['feature_preprocessor'] = self.exclude_preprocessors

if self.include_estimators is not None and self.exclude_estimators is not None:
raise ValueError('Cannot specify include_estimators and '
'exclude_estimators.')
elif self.include_estimators is not None:
if self.task in CLASSIFICATION_TASKS:
include['classifier'] = self.include_estimators
elif self.task in REGRESSION_TASKS:
include['regressor'] = self.include_estimators
else:
raise ValueError(self.task)
elif self.exclude_estimators is not None:
if self.task in CLASSIFICATION_TASKS:
exclude['classifier'] = self.exclude_estimators
elif self.task in REGRESSION_TASKS:
exclude['regressor'] = self.exclude_estimators
else:
raise ValueError(self.task)
include, exclude = parse_include_exclude_components(
task=self.task,
include_estimators=self.include_estimators,
exclude_estimators=self.exclude_estimators,
include_preprocessors=self.include_preprocessors,
exclude_preprocessors=self.exclude_preprocessors,
)

ta_kwargs = dict(
backend=copy.deepcopy(self.backend),
43 changes: 30 additions & 13 deletions autosklearn/util/pipeline.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

from ConfigSpace.configuration_space import ConfigurationSpace

@@ -24,12 +24,13 @@
]


def get_configuration_space(info: Dict[str, Any],
include_estimators: Optional[List[str]] = None,
exclude_estimators: Optional[List[str]] = None,
include_preprocessors: Optional[List[str]] = None,
exclude_preprocessors: Optional[List[str]] = None
) -> ConfigurationSpace:
def parse_include_exclude_components(
task: int,
include_estimators: Optional[List[str]] = None,
exclude_estimators: Optional[List[str]] = None,
include_preprocessors: Optional[List[str]] = None,
exclude_preprocessors: Optional[List[str]] = None
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
exclude = dict()
include = dict()
if include_preprocessors is not None and \
@@ -46,19 +47,35 @@ def get_configuration_space(info: Dict[str, Any],
raise ValueError('Cannot specify include_estimators and '
'exclude_estimators.')
elif include_estimators is not None:
if info['task'] in CLASSIFICATION_TASKS:
if task in CLASSIFICATION_TASKS:
include['classifier'] = include_estimators
elif info['task'] in REGRESSION_TASKS:
elif task in REGRESSION_TASKS:
include['regressor'] = include_estimators
else:
raise ValueError(info['task'])
raise ValueError(task)
elif exclude_estimators is not None:
if info['task'] in CLASSIFICATION_TASKS:
if task in CLASSIFICATION_TASKS:
exclude['classifier'] = exclude_estimators
elif info['task'] in REGRESSION_TASKS:
elif task in REGRESSION_TASKS:
exclude['regressor'] = exclude_estimators
else:
raise ValueError(info['task'])
raise ValueError(task)
return include, exclude


def get_configuration_space(info: Dict[str, Any],
include_estimators: Optional[List[str]] = None,
exclude_estimators: Optional[List[str]] = None,
include_preprocessors: Optional[List[str]] = None,
exclude_preprocessors: Optional[List[str]] = None
) -> ConfigurationSpace:
include, exclude = parse_include_exclude_components(
task=info['task'],
include_estimators=include_estimators,
exclude_estimators=exclude_estimators,
include_preprocessors=include_preprocessors,
exclude_preprocessors=exclude_preprocessors,
)

if info['task'] in REGRESSION_TASKS:
return _get_regression_configuration_space(info, include, exclude)
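To illustrate the contract of the extracted helper, which now serves both smbo.py and get_configuration_space (a sketch; the task constant and the component names are assumed examples):

```python
from autosklearn.constants import BINARY_CLASSIFICATION
from autosklearn.util.pipeline import parse_include_exclude_components

include, exclude = parse_include_exclude_components(
    task=BINARY_CLASSIFICATION,
    include_estimators=["random_forest"],
    exclude_preprocessors=["fast_ica"],
)
# Classification tasks map estimators to the 'classifier' step;
# preprocessors always map to 'feature_preprocessor'.
assert include == {"classifier": ["random_forest"]}
assert exclude == {"feature_preprocessor": ["fast_ica"]}
```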
1 change: 1 addition & 0 deletions doc/manual.rst
@@ -22,6 +22,7 @@ aspects of its usage:
* `Iterating over the models <examples/40_advanced/example_get_pipeline_components.html>`_
* `Using custom metrics <examples/40_advanced/example_metrics.html>`_
* `Pandas Train and Test inputs <examples/40_advanced/example_pandas_train_test.html>`_
* `Train a single configuration <examples/40_advanced/example_single_configuration.html>`_
* `Resampling strategies <examples/40_advanced/example_resampling.html>`_
* `Parallel usage (manual) <examples/60_search/example_parallel_manual_spawning.html>`_
* `Parallel usage (n_jobs) <examples/60_search/example_parallel_n_jobs.html>`_