Commit b3867b4

Enabled pipeline fit (#1096)
* Enabled pipeline fit
* Feedback from comments
* Feedback from comments
* clean pipeline parsing
* Feedback from PR
1 parent b46a918 commit b3867b4

File tree

8 files changed: 633 additions, 148 deletions


autosklearn/automl.py

Lines changed: 268 additions & 26 deletions
Large diffs are not rendered by default.

autosklearn/estimators.py

Lines changed: 128 additions & 51 deletions
@@ -1,13 +1,20 @@
 # -*- encoding: utf-8 -*-

-from typing import Optional, Dict, List
+from typing import Optional, Dict, List, Tuple, Union

+from ConfigSpace.configuration_space import Configuration
 import dask.distributed
 import joblib
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
 from sklearn.utils.multiclass import type_of_target
+from smac.runhistory.runhistory import RunInfo, RunValue

+from autosklearn.data.validation import (
+    SUPPORTED_FEAT_TYPES,
+    SUPPORTED_TARGET_TYPES,
+)
+from autosklearn.pipeline.base import BasePipeline
 from autosklearn.automl import AutoMLClassifier, AutoMLRegressor, AutoML
 from autosklearn.metrics import Scorer
 from autosklearn.util.backend import create

@@ -271,8 +278,15 @@ def __init__(
         self.load_models = load_models

         self.automl_ = None  # type: Optional[AutoML]
-        # n_jobs after conversion to a number (b/c default is None)
+
+        # Handle the number of jobs and the time for them
         self._n_jobs = None
+        if self.n_jobs is None or self.n_jobs == 1:
+            self._n_jobs = 1
+        elif self.n_jobs == -1:
+            self._n_jobs = joblib.cpu_count()
+        else:
+            self._n_jobs = self.n_jobs

         super().__init__()

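For reference, the n_jobs resolution that this hunk moves into the constructor can be sketched as a standalone helper. The function resolve_n_jobs below is illustrative only and is not part of auto-sklearn:

    import joblib

    def resolve_n_jobs(n_jobs):
        # None or 1 -> a single worker, -1 -> all detected cores,
        # any other value is used as given.
        if n_jobs is None or n_jobs == 1:
            return 1
        if n_jobs == -1:
            return joblib.cpu_count()
        return n_jobs

    print(resolve_n_jobs(None), resolve_n_jobs(-1), resolve_n_jobs(4))  # e.g. 1, <core count>, 4
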
@@ -281,35 +295,24 @@ def __getstate__(self):
         self.dask_client = None
         return self.__dict__

-    def build_automl(
-        self,
-        seed: int,
-        ensemble_size: int,
-        initial_configurations_via_metalearning: int,
-        tmp_folder: str,
-        output_folder: str,
-        smac_scenario_args: Optional[Dict] = None,
-    ):
+    def build_automl(self):

         backend = create(
-            temporary_directory=tmp_folder,
-            output_directory=output_folder,
+            temporary_directory=self.tmp_folder,
+            output_directory=self.output_folder,
             delete_tmp_folder_after_terminate=self.delete_tmp_folder_after_terminate,
             delete_output_folder_after_terminate=self.delete_output_folder_after_terminate,
         )

-        if smac_scenario_args is None:
-            smac_scenario_args = self.smac_scenario_args
-
         automl = self._get_automl_class()(
             backend=backend,
             time_left_for_this_task=self.time_left_for_this_task,
             per_run_time_limit=self.per_run_time_limit,
-            initial_configurations_via_metalearning=initial_configurations_via_metalearning,
-            ensemble_size=ensemble_size,
+            initial_configurations_via_metalearning=self.initial_configurations_via_metalearning,
+            ensemble_size=self.ensemble_size,
             ensemble_nbest=self.ensemble_nbest,
             max_models_on_disc=self.max_models_on_disc,
-            seed=seed,
+            seed=self.seed,
             memory_limit=self.memory_limit,
             include_estimators=self.include_estimators,
             exclude_estimators=self.exclude_estimators,

@@ -321,7 +324,7 @@ def build_automl(
             dask_client=self.dask_client,
             get_smac_object_callback=self.get_smac_object_callback,
             disable_evaluator_output=self.disable_evaluator_output,
-            smac_scenario_args=smac_scenario_args,
+            smac_scenario_args=self.smac_scenario_args,
             logging_config=self.logging_config,
             metadata_directory=self.metadata_directory,
             metric=self.metric,

@@ -332,32 +335,82 @@ def build_automl(

     def fit(self, **kwargs):

-        # Handle the number of jobs and the time for them
-        if self.n_jobs is None or self.n_jobs == 1:
-            self._n_jobs = 1
-        elif self.n_jobs == -1:
-            self._n_jobs = joblib.cpu_count()
-        else:
-            self._n_jobs = self.n_jobs
-
         # Automatically set the cutoff time per task
         if self.per_run_time_limit is None:
             self.per_run_time_limit = self._n_jobs * self.time_left_for_this_task // 10

-        seed = self.seed
-        self.automl_ = self.build_automl(
-            seed=seed,
-            ensemble_size=self.ensemble_size,
-            initial_configurations_via_metalearning=(
-                self.initial_configurations_via_metalearning
-            ),
-            tmp_folder=self.tmp_folder,
-            output_folder=self.output_folder,
-        )
+        if self.automl_ is None:
+            self.automl_ = self.build_automl()
         self.automl_.fit(load_models=self.load_models, **kwargs)

         return self

+    def fit_pipeline(
+        self,
+        X: SUPPORTED_FEAT_TYPES,
+        y: SUPPORTED_TARGET_TYPES,
+        config: Union[Configuration, Dict[str, Union[str, float, int]]],
+        dataset_name: Optional[str] = None,
+        X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
+        y_test: Optional[SUPPORTED_TARGET_TYPES] = None,
+        feat_type: Optional[List[str]] = None,
+        *args,
+        **kwargs: Dict,
+    ) -> Tuple[Optional[BasePipeline], RunInfo, RunValue]:
+        """ Fits an individual pipeline configuration and returns
+        the result to the user.
+
+        The Estimator constraints are honored, for example the resampling
+        strategy, or memory constraints, unless directly provided to the method.
+        By default, this method supports the same signature as fit(), and any extra
+        arguments are redirected to the TAE evaluation function, which allows for
+        further customization while building a pipeline.
+
+        Any additional argument provided is directly passed to the worker exercising the run.
+
+        Parameters
+        ----------
+        X: array-like, shape = (n_samples, n_features)
+            The features used for training
+        y: array-like
+            The labels used for training
+        X_test: Optional array-like, shape = (n_samples, n_features)
+            If provided, the testing performance will be tracked on these features.
+        y_test: array-like
+            If provided, the testing performance will be tracked on these labels
+        config: Union[Configuration, Dict[str, Union[str, float, int]]]
+            A configuration object used to define the pipeline steps.
+            If a dictionary is passed, a configuration is created based on this dictionary.
+        dataset_name: Optional[str]
+            Name that will be used to tag the Auto-Sklearn run and identify the
+            Auto-Sklearn run
+        feat_type : list, optional (default=None)
+            List of str of `len(X.shape[1])` describing the attribute type.
+            Possible types are `Categorical` and `Numerical`. `Categorical`
+            attributes will be automatically One-Hot encoded. The values
+            used for a categorical attribute must be integers, obtained for
+            example by `sklearn.preprocessing.LabelEncoder
+            <http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html>`_.
+
+        Returns
+        -------
+        pipeline: Optional[BasePipeline]
+            The fitted pipeline. In case of failure while fitting the pipeline,
+            a None is returned.
+        run_info: RunInfo
+            A named tuple that contains the configuration launched
+        run_value: RunValue
+            A named tuple that contains the result of the run
+        """
+        if self.automl_ is None:
+            self.automl_ = self.build_automl()
+        return self.automl_.fit_pipeline(X=X, y=y,
+                                         dataset_name=dataset_name,
+                                         config=config,
+                                         feat_type=feat_type,
+                                         X_test=X_test, y_test=y_test,
+                                         *args, **kwargs)
+
     def fit_ensemble(self, y, task=None, precision=32,
                      dataset_name=None, ensemble_nbest=None,
                      ensemble_size=None):

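A minimal usage sketch of the fit_pipeline() method added above, based on the signature shown in this diff. The dataset, time limits, and the way the configuration is sampled are illustrative only and are not prescribed by the commit:

    import sklearn.datasets
    import sklearn.model_selection
    from autosklearn.classification import AutoSklearnClassifier

    X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X, y, random_state=1)

    cls = AutoSklearnClassifier(time_left_for_this_task=120, per_run_time_limit=30)

    # Sample one configuration from the estimator's search space; per the docstring
    # above, a plain dict of hyperparameter values would also be accepted.
    space = cls.get_configuration_space(X_train, y_train, dataset_name='breast_cancer')
    config = space.sample_configuration()

    pipeline, run_info, run_value = cls.fit_pipeline(
        X=X_train, y=y_train,
        X_test=X_test, y_test=y_test,
        config=config,
        dataset_name='breast_cancer',
    )

    # pipeline is None when the run failed; run_info/run_value carry the SMAC record.
    if pipeline is not None:
        print(run_value.cost)
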
@@ -401,17 +454,9 @@ def fit_ensemble(self, y, task=None, precision=32,
         """
         if self.automl_ is None:
             # Build a dummy automl object to call fit_ensemble
-            self.automl_ = self.build_automl(
-                seed=self.seed,
-                ensemble_size=(
-                    ensemble_size
-                    if ensemble_size is not None else
-                    self.ensemble_size
-                ),
-                initial_configurations_via_metalearning=0,
-                tmp_folder=self.tmp_folder,
-                output_folder=self.output_folder,
-            )
+            # The ensemble size is honored in the .automl_.fit_ensemble
+            # call
+            self.automl_ = self.build_automl()
         self.automl_.fit_ensemble(
             y=y,
             task=task,

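As a usage note on the simplified branch above: the ensemble_size argument is now honored inside automl_.fit_ensemble rather than at build time, so a common pattern is to search first with ensemble_size=0 and build the ensemble afterwards. The dataset and numbers below are illustrative:

    import sklearn.datasets
    from autosklearn.classification import AutoSklearnClassifier

    X, y = sklearn.datasets.load_digits(return_X_y=True)

    cls = AutoSklearnClassifier(time_left_for_this_task=60, ensemble_size=0)
    cls.fit(X, y)                          # run the search without building an ensemble
    cls.fit_ensemble(y, ensemble_size=20)  # build the ensemble from the stored models
    print(cls.sprint_statistics())
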
@@ -513,8 +558,40 @@ def sprint_statistics(self):
     def _get_automl_class(self):
         raise NotImplementedError()

-    def get_configuration_space(self, X, y):
-        return self.automl_.configuration_space
+    def get_configuration_space(
+        self,
+        X: SUPPORTED_FEAT_TYPES,
+        y: SUPPORTED_TARGET_TYPES,
+        X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
+        y_test: Optional[SUPPORTED_TARGET_TYPES] = None,
+        dataset_name: Optional[str] = None,
+    ):
+        """
+        Returns the Configuration Space object, from which Auto-Sklearn
+        will sample configurations and build pipelines.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix of shape = [n_samples, n_features]
+            Array with the training features, used to get characteristics like
+            data sparsity
+        y : array-like, shape = [n_samples] or [n_samples, n_outputs]
+            Array with the problem labels
+        X_test : array-like or sparse matrix of shape = [n_samples, n_features]
+            Array with features used for performance estimation
+        y_test : array-like, shape = [n_samples] or [n_samples, n_outputs]
+            Array with the problem labels for the testing split
+        dataset_name: Optional[str]
+            A string to tag the Auto-Sklearn run
+        """
+        if self.automl_ is None:
+            self.automl_ = self.build_automl()
+        return self.automl_.fit(
+            X, y,
+            X_test=X_test, y_test=y_test,
+            dataset_name=dataset_name,
+            only_return_configuration_space=True,
+        ) if self.automl_.configuration_space is None else self.automl_.configuration_space


 class AutoSklearnClassifier(AutoSklearnEstimator, ClassifierMixin):
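The rewritten get_configuration_space() builds the space lazily: the first call runs automl_.fit(..., only_return_configuration_space=True) and later calls reuse the cached space. A short sketch follows; the dataset is illustrative:

    import sklearn.datasets
    from autosklearn.regression import AutoSklearnRegressor

    X, y = sklearn.datasets.load_diabetes(return_X_y=True)

    reg = AutoSklearnRegressor(time_left_for_this_task=60)
    space = reg.get_configuration_space(X, y, dataset_name='diabetes')

    # The ConfigurationSpace enumerates every tunable hyperparameter of the pipeline.
    print(len(space.get_hyperparameter_names()), 'hyperparameters')
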

autosklearn/smbo.py

Lines changed: 9 additions & 29 deletions
@@ -23,14 +23,14 @@
 import autosklearn.metalearning
 from autosklearn.constants import MULTILABEL_CLASSIFICATION, \
     BINARY_CLASSIFICATION, TASK_TYPES_TO_STRING, CLASSIFICATION_TASKS, \
-    REGRESSION_TASKS, MULTICLASS_CLASSIFICATION, REGRESSION, \
-    MULTIOUTPUT_REGRESSION
+    MULTICLASS_CLASSIFICATION, REGRESSION, MULTIOUTPUT_REGRESSION
 from autosklearn.ensemble_builder import EnsembleBuilderManager
 from autosklearn.metalearning.mismbo import suggest_via_metalearning
 from autosklearn.data.abstract_data_manager import AbstractDataManager
 from autosklearn.evaluation import ExecuteTaFuncWithQueue, get_cost_of_crash
 from autosklearn.util.logging_ import get_named_client_logger
 from autosklearn.util.parallel import preload_modules
+from autosklearn.util.pipeline import parse_include_exclude_components
 from autosklearn.metalearning.metalearning.meta_base import MetaBase
 from autosklearn.metalearning.metafeatures.metafeatures import \
     calculate_all_metafeatures_with_labels, calculate_all_metafeatures_encoded_labels

@@ -414,33 +414,13 @@ def run_smbo(self):
        # evaluator, which takes into account that a run can be killed prior
        # to the model being fully fitted; thus putting intermediate results
        # into a queue and querying them once the time is over
-        exclude = dict()
-        include = dict()
-        if self.include_preprocessors is not None and self.exclude_preprocessors is not None:
-            raise ValueError('Cannot specify include_preprocessors and '
-                             'exclude_preprocessors.')
-        elif self.include_preprocessors is not None:
-            include['feature_preprocessor'] = self.include_preprocessors
-        elif self.exclude_preprocessors is not None:
-            exclude['feature_preprocessor'] = self.exclude_preprocessors
-
-        if self.include_estimators is not None and self.exclude_estimators is not None:
-            raise ValueError('Cannot specify include_estimators and '
-                             'exclude_estimators.')
-        elif self.include_estimators is not None:
-            if self.task in CLASSIFICATION_TASKS:
-                include['classifier'] = self.include_estimators
-            elif self.task in REGRESSION_TASKS:
-                include['regressor'] = self.include_estimators
-            else:
-                raise ValueError(self.task)
-        elif self.exclude_estimators is not None:
-            if self.task in CLASSIFICATION_TASKS:
-                exclude['classifier'] = self.exclude_estimators
-            elif self.task in REGRESSION_TASKS:
-                exclude['regressor'] = self.exclude_estimators
-            else:
-                raise ValueError(self.task)
+        include, exclude = parse_include_exclude_components(
+            task=self.task,
+            include_estimators=self.include_estimators,
+            exclude_estimators=self.exclude_estimators,
+            include_preprocessors=self.include_preprocessors,
+            exclude_preprocessors=self.exclude_preprocessors,
+        )

        ta_kwargs = dict(
            backend=copy.deepcopy(self.backend),

autosklearn/util/pipeline.py

Lines changed: 30 additions & 13 deletions
@@ -1,5 +1,5 @@
 # -*- encoding: utf-8 -*-
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple

 from ConfigSpace.configuration_space import ConfigurationSpace

@@ -24,12 +24,13 @@
 ]


-def get_configuration_space(info: Dict[str, Any],
-                            include_estimators: Optional[List[str]] = None,
-                            exclude_estimators: Optional[List[str]] = None,
-                            include_preprocessors: Optional[List[str]] = None,
-                            exclude_preprocessors: Optional[List[str]] = None
-                            ) -> ConfigurationSpace:
+def parse_include_exclude_components(
+    task: int,
+    include_estimators: Optional[List[str]] = None,
+    exclude_estimators: Optional[List[str]] = None,
+    include_preprocessors: Optional[List[str]] = None,
+    exclude_preprocessors: Optional[List[str]] = None
+) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
     exclude = dict()
     include = dict()
     if include_preprocessors is not None and \

@@ -46,19 +47,35 @@ def get_configuration_space(info: Dict[str, Any],
         raise ValueError('Cannot specify include_estimators and '
                          'exclude_estimators.')
     elif include_estimators is not None:
-        if info['task'] in CLASSIFICATION_TASKS:
+        if task in CLASSIFICATION_TASKS:
             include['classifier'] = include_estimators
-        elif info['task'] in REGRESSION_TASKS:
+        elif task in REGRESSION_TASKS:
             include['regressor'] = include_estimators
         else:
-            raise ValueError(info['task'])
+            raise ValueError(task)
     elif exclude_estimators is not None:
-        if info['task'] in CLASSIFICATION_TASKS:
+        if task in CLASSIFICATION_TASKS:
             exclude['classifier'] = exclude_estimators
-        elif info['task'] in REGRESSION_TASKS:
+        elif task in REGRESSION_TASKS:
             exclude['regressor'] = exclude_estimators
         else:
-            raise ValueError(info['task'])
+            raise ValueError(task)
+    return include, exclude
+
+
+def get_configuration_space(info: Dict[str, Any],
+                            include_estimators: Optional[List[str]] = None,
+                            exclude_estimators: Optional[List[str]] = None,
+                            include_preprocessors: Optional[List[str]] = None,
+                            exclude_preprocessors: Optional[List[str]] = None
+                            ) -> ConfigurationSpace:
+    include, exclude = parse_include_exclude_components(
+        task=info['task'],
+        include_estimators=include_estimators,
+        exclude_estimators=exclude_estimators,
+        include_preprocessors=include_preprocessors,
+        exclude_preprocessors=exclude_preprocessors,
+    )

     if info['task'] in REGRESSION_TASKS:
         return _get_regression_configuration_space(info, include, exclude)
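A small sketch of the newly factored-out helper, assuming a classification task constant; the component names passed in are arbitrary examples, since the helper only routes them into the right dictionary keys without validating them:

    from autosklearn.constants import BINARY_CLASSIFICATION
    from autosklearn.util.pipeline import parse_include_exclude_components

    include, exclude = parse_include_exclude_components(
        task=BINARY_CLASSIFICATION,
        include_estimators=['random_forest', 'mlp'],
        exclude_preprocessors=['fast_ica'],
    )
    # include == {'classifier': ['random_forest', 'mlp']}
    # exclude == {'feature_preprocessor': ['fast_ica']}
    print(include, exclude)
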

doc/manual.rst

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ aspects of its usage:
 * `Iterating over the models <examples/40_advanced/example_get_pipeline_components.html>`_
 * `Using custom metrics <examples/40_advanced/example_metrics.html>`_
 * `Pandas Train and Test inputs <examples/40_advanced/example_pandas_train_test.html>`_
+* `Train a single configuration <examples/40_advanced/example_single_configuration.html>`_
 * `Resampling strategies <examples/40_advanced/example_resampling.html>`_
 * `Parallel usage (manual) <examples/60_search/example_parallel_manual_spawning.html>`_
 * `Parallel usage (n_jobs) <examples/60_search/example_parallel_n_jobs.html>`_
