
Commit dce6a5c

Adding tabular regression pipeline (#85)
* removed old supported_tasks dictionary from heads, added some docstrings and some small fixes
* removed old supported_tasks attribute and updated doc strings in base backbone and base head components
* removed old supported_tasks attribute from network backbones
* put time series backbones in separate files, add doc strings and refactored search space arguments
* split image networks into separate files, add doc strings and refactor search space
* fix typo
* add an initial simple backbone test similar to the network head test
* fix flake8
* fixed imports in backbones and heads
* added new network backbone and head tests
* enabled tests for adding custom backbones and heads, added required properties to base head and base backbone
* adding tabular regression pipeline
* fix flake8
* adding tabular regression pipeline
* fix flake8
* fix regression test
* fix indentation and comments, undo change in base network
* pipeline fitting tests now check the expected output shape dynamically based on the input data
* refactored trainer tests, added trainer test for regression
* remove regression from mixup unittest
* use pandas unique instead of numpy
* [IMPORTANT] added proper target casting based on task type to base trainer
* adding tabular regression task to api
* adding tabular regression example, some small fixes
* new/more tests for tabular regression
* fix mypy and flake8 errors from merge
* fix issues with new weighted loss and regression tasks
* change tabular column transformer to use new fit_dictionary_tabular fixture
* fixing tests, replaced num_classes with output_shape
* fixes after merge
* adding voting regressor wrapper
* fix mypy and flake8
* updated example
* lower r2 target
* address comments
* increasing timeout
* increase number of labels in test_losses because it occasionally failed if one class was not in the labels
* lower regression lr in score test until seeding properly works
* fix randomization in feature validator test
1 parent b48b952 commit dce6a5c

33 files changed (+1420, −478 lines)
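One bullet above is flagged [IMPORTANT]: target casting based on task type in the base trainer. That change lives in a file not shown on this page; a hedged sketch of the idea, with an illustrative helper name (the actual trainer code is not reproduced here):

import torch

def cast_targets(y: torch.Tensor, is_regression: bool) -> torch.Tensor:
    if is_regression:
        # Regression losses such as MSELoss expect float targets.
        return y.float()
    # Classification losses such as CrossEntropyLoss expect integer class ids.
    return y.long()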

.github/workflows/pytest.yml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ jobs:
     - name: Run tests
       run: |
         if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml'; fi
-        python -m pytest --durations=20 --timeout=300 --timeout-method=thread -v $codecov test
+        python -m pytest --durations=20 --timeout=500 --timeout-method=thread -v $codecov test
     - name: Check for files left behind by test
       if: ${{ always() }}
       run: |

autoPyTorch/api/base_task.py

Lines changed: 5 additions & 15 deletions
@@ -73,7 +73,8 @@ def send_warnings_to_log(
     with warnings.catch_warnings():
         warnings.showwarning = send_warnings_to_log
         if task in REGRESSION_TASKS:
-            prediction = pipeline.predict(X_, batch_size=batch_size)
+            # Voting regressor does not support batch size
+            prediction = pipeline.predict(X_)
         else:
             # Voting classifier predict proba does not support batch size
             prediction = pipeline.predict_proba(X_)
@@ -161,7 +162,7 @@ def __init__(
             delete_tmp_folder_after_terminate=delete_tmp_folder_after_terminate,
             delete_output_folder_after_terminate=delete_output_folder_after_terminate,
         )
-        self.task_type = task_type
+        self.task_type = task_type or ""
         self._stopwatch = StopWatch()

         self.pipeline_options = replace_string_bool_to_bool(json.load(open(
@@ -789,7 +790,7 @@ def _search(
             max_models_on_disc=self.max_models_on_disc,
             seed=self.seed,
             max_iterations=None,
-            read_at_most=np.inf,
+            read_at_most=sys.maxsize,
             ensemble_memory_limit=self._memory_limit,
             random_state=self.seed,
             precision=precision,
@@ -1050,7 +1051,7 @@ def predict(

         all_predictions = joblib.Parallel(n_jobs=n_jobs)(
             joblib.delayed(_pipeline_predict)(
-                models[identifier], X_test, batch_size, self._logger, self.task_type
+                models[identifier], X_test, batch_size, self._logger, STRING_TO_TASK_TYPES[self.task_type]
             )
             for identifier in self.ensemble_.get_selected_model_identifiers()
         )
@@ -1064,17 +1065,6 @@ def predict(

         predictions = self.ensemble_.predict(all_predictions)

-        if self.task_type in REGRESSION_TASKS:
-            # Make sure prediction probabilities
-            # are within a valid range
-            # Individual models are checked in _pipeline_predict
-            if (
-                (predictions >= 0).all() and (predictions <= 1).all()
-            ):
-                raise ValueError("For ensemble {}, prediction probability not within [0, 1]!".format(
-                    self.ensemble_)
-                )
-
         self._clean_logger()

         return predictions
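Note on the read_at_most change above: the ensemble builder uses this value as a file count, and np.inf is a float, so integer-expecting operations can fail with it, while sys.maxsize is a true int with effectively the same "no limit" meaning. A minimal illustrative sketch of one such failure mode (the take_first helper is hypothetical, not from the codebase):

import sys

import numpy as np

def take_first(items: list, read_at_most: int) -> list:
    # Slicing needs a real integer; a float such as np.inf raises TypeError.
    return items[:read_at_most]

files = ["pred_1.npy", "pred_2.npy", "pred_3.npy"]
print(take_first(files, sys.maxsize))  # works: sys.maxsize is an int
try:
    take_first(files, np.inf)
except TypeError as err:
    print(err)  # slice indices must be integers or None or have an __index__ method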

autoPyTorch/api/tabular_regression.py

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
import os
import uuid
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np

import pandas as pd

from autoPyTorch.api.base_task import BaseTask
from autoPyTorch.constants import (
    TABULAR_REGRESSION,
    TASK_TYPES_TO_STRING
)
from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.resampling_strategy import (
    CrossValTypes,
    HoldoutValTypes,
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.pipeline.tabular_regression import TabularRegressionPipeline
from autoPyTorch.utils.backend import Backend
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates


class TabularRegressionTask(BaseTask):
    """
    Tabular Regression API to the pipelines.
    Args:
        seed (int): seed to be used for reproducibility.
        n_jobs (int), (default=1): number of consecutive processes to spawn.
        logging_config (Optional[Dict]): specifies configuration
            for logging, if None, it is loaded from the logging.yaml
        ensemble_size (int), (default=50): Number of models added to the ensemble built by
            Ensemble selection from libraries of models.
            Models are drawn with replacement.
        ensemble_nbest (int), (default=50): only consider the ensemble_nbest
            models to build the ensemble
        max_models_on_disc (int), (default=50): maximum number of models saved to disc.
            Also controls the size of the ensemble, as any additional models will be deleted.
            Must be greater than or equal to 1.
        temporary_directory (str): folder to store configuration output and log file
        output_directory (str): folder to store predictions for optional test set
        delete_tmp_folder_after_terminate (bool): determines whether to delete the temporary
            directory when finished
        include_components (Optional[Dict]): If None, all possible components are used.
            Otherwise specifies a set of components to use.
        exclude_components (Optional[Dict]): If None, all possible components are used.
            Otherwise specifies a set of components not to use. Incompatible with
            include_components.
    """

    def __init__(
        self,
        seed: int = 1,
        n_jobs: int = 1,
        logging_config: Optional[Dict] = None,
        ensemble_size: int = 50,
        ensemble_nbest: int = 50,
        max_models_on_disc: int = 50,
        temporary_directory: Optional[str] = None,
        output_directory: Optional[str] = None,
        delete_tmp_folder_after_terminate: bool = True,
        delete_output_folder_after_terminate: bool = True,
        include_components: Optional[Dict] = None,
        exclude_components: Optional[Dict] = None,
        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        backend: Optional[Backend] = None,
        search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
    ):
        super().__init__(
            seed=seed,
            n_jobs=n_jobs,
            logging_config=logging_config,
            ensemble_size=ensemble_size,
            ensemble_nbest=ensemble_nbest,
            max_models_on_disc=max_models_on_disc,
            temporary_directory=temporary_directory,
            output_directory=output_directory,
            delete_tmp_folder_after_terminate=delete_tmp_folder_after_terminate,
            delete_output_folder_after_terminate=delete_output_folder_after_terminate,
            include_components=include_components,
            exclude_components=exclude_components,
            backend=backend,
            resampling_strategy=resampling_strategy,
            resampling_strategy_args=resampling_strategy_args,
            search_space_updates=search_space_updates,
            task_type=TASK_TYPES_TO_STRING[TABULAR_REGRESSION],
        )

    def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]:
        if not isinstance(dataset, TabularDataset):
            raise ValueError("Dataset is incompatible for the given task: {}".format(
                type(dataset)
            ))
        return {'task_type': dataset.task_type,
                'output_type': dataset.output_type,
                'issparse': dataset.issparse,
                'numerical_columns': dataset.numerical_columns,
                'categorical_columns': dataset.categorical_columns}

    def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularRegressionPipeline:
        return TabularRegressionPipeline(dataset_properties=dataset_properties)

    def search(self,
               optimize_metric: str,
               X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
               y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
               X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
               y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
               dataset_name: Optional[str] = None,
               budget_type: Optional[str] = None,
               budget: Optional[float] = None,
               total_walltime_limit: int = 100,
               func_eval_time_limit: int = 60,
               traditional_per_total_budget: float = 0.1,
               memory_limit: Optional[int] = 4096,
               smac_scenario_args: Optional[Dict[str, Any]] = None,
               get_smac_object_callback: Optional[Callable] = None,
               all_supported_metrics: bool = True,
               precision: int = 32,
               disable_file_output: List = [],
               load_models: bool = True,
               ) -> 'BaseTask':
        """
        Search for the best pipeline configuration for the given dataset.

        Search both optimizes the machine learning models and builds an
        ensemble out of them. To disable ensembling, set ensemble_size==0.

        Args:
            X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame]
                A pair of features (X_train) and targets (y_train) used to fit a
                pipeline. Additionally, a holdout of these pairs (X_test, y_test) can
                be provided to track the generalization performance of each stage.
            optimize_metric (str): name of the metric that is used to
                evaluate a pipeline.
            budget_type (Optional[str]):
                Type of budget to be used when fitting the pipeline.
                Either 'epochs' or 'runtime'. If not provided, uses
                the default in the pipeline config ('epochs').
            budget (Optional[float]):
                Budget to fit a single run of the pipeline. If not
                provided, uses the default in the pipeline config.
            total_walltime_limit (int), (default=100): Time limit
                in seconds for the search of appropriate models.
                By increasing this value, autopytorch has a higher
                chance of finding better models.
            func_eval_time_limit (int), (default=60): Time limit
                for a single call to the machine learning model.
                Model fitting will be terminated if the machine
                learning algorithm runs over the time limit. Set
                this value high enough so that typical machine
                learning algorithms can be fit on the training
                data.
            traditional_per_total_budget (float), (default=0.1):
                Percent of total walltime to be allocated for
                running traditional classifiers.
            memory_limit (Optional[int]), (default=4096): Memory
                limit in MB for the machine learning algorithm. autopytorch
                will stop fitting the machine learning algorithm if it tries
                to allocate more than memory_limit MB. If None is provided,
                no memory limit is set. In case of multi-processing, memory_limit
                will be per job. This memory limit also applies to the ensemble
                creation process.
            smac_scenario_args (Optional[Dict]): Additional arguments inserted
                into the scenario of SMAC. See the
                [SMAC documentation](https://automl.github.io/SMAC3/master/options.html?highlight=scenario#scenario)
            get_smac_object_callback (Optional[Callable]): Callback function
                to create an object of class
                [smac.optimizer.smbo.SMBO](https://automl.github.io/SMAC3/master/apidoc/smac.optimizer.smbo.html).
                The function must accept the arguments scenario_dict,
                instances, num_params, runhistory, seed and ta. This is
                an advanced feature. Use only if you are familiar with
                [SMAC](https://automl.github.io/SMAC3/master/index.html).
            all_supported_metrics (bool), (default=True): if True, all
                metrics supporting the current task will be calculated
                for each pipeline and results will be available via cv_results.
            precision (int), (default=32): Numeric precision used when loading
                ensemble data. Can be either '16', '32' or '64'.
            disable_file_output (Union[bool, List]):
            load_models (bool), (default=True): Whether to load the
                models after fitting AutoPyTorch.

        Returns:
            self

        """
        if dataset_name is None:
            dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

        # we have to create a logger at this point for the validator
        self._logger = self._get_logger(dataset_name)

        # Create a validator object to make sure that the data provided by
        # the user matches the autopytorch requirements
        self.InputValidator = TabularInputValidator(
            is_classification=False,
            logger_port=self._logger_port,
        )

        # Fit an input validator to check the provided data.
        # Also, an encoder is fit to both train and test data,
        # to prevent unseen categories during inference
        self.InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

        self.dataset = TabularDataset(
            X=X_train, Y=y_train,
            X_test=X_test, Y_test=y_test,
            validator=self.InputValidator,
            resampling_strategy=self.resampling_strategy,
            resampling_strategy_args=self.resampling_strategy_args,
        )

        return self._search(
            dataset=self.dataset,
            optimize_metric=optimize_metric,
            budget_type=budget_type,
            budget=budget,
            total_walltime_limit=total_walltime_limit,
            func_eval_time_limit=func_eval_time_limit,
            traditional_per_total_budget=traditional_per_total_budget,
            memory_limit=memory_limit,
            smac_scenario_args=smac_scenario_args,
            get_smac_object_callback=get_smac_object_callback,
            all_supported_metrics=all_supported_metrics,
            precision=precision,
            disable_file_output=disable_file_output,
            load_models=load_models,
        )

    def predict(
        self,
        X_test: np.ndarray,
        batch_size: Optional[int] = None,
        n_jobs: int = 1
    ) -> np.ndarray:
        if self.InputValidator is None or not self.InputValidator._is_fitted:
            raise ValueError("predict() is only supported after calling search. Kindly call first "
                             "the estimator fit() method.")

        X_test = self.InputValidator.feature_validator.transform(X_test)
        predicted_values = super().predict(X_test, batch_size=batch_size,
                                           n_jobs=n_jobs)

        # Allow to predict in the original domain -- that is, the user is not interested
        # in our encoded values
        return self.InputValidator.target_validator.inverse_transform(predicted_values)
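For orientation, a minimal usage sketch of the new task class, pieced together from the signatures above; the toy data and argument values are illustrative and not taken from the commit's example script:

import numpy as np

from autoPyTorch.api.tabular_regression import TabularRegressionTask

# Toy regression data; any tabular X/y the TabularInputValidator accepts will do.
rng = np.random.default_rng(1)
X = rng.random((200, 4))
y = X @ np.array([1.0, 2.0, 3.0, 4.0]) + 0.1 * rng.standard_normal(200)

api = TabularRegressionTask(seed=1, ensemble_size=10)
api.search(
    optimize_metric='r2',  # the commit message mentions an r2 target
    X_train=X, y_train=y,
    total_walltime_limit=300,
    func_eval_time_limit=60,
)
# predict() transforms features with the fitted validator and returns
# predictions in the original target domain.
y_pred = api.predict(X)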

autoPyTorch/datasets/base_dataset.py

Lines changed: 10 additions & 6 deletions
@@ -11,6 +11,7 @@

 import torchvision

+from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
 from autoPyTorch.datasets.resampling_strategy import (
     CROSS_VAL_FN,
     CrossValTypes,
@@ -113,11 +114,15 @@ def __init__(
         self.resampling_strategy_args = resampling_strategy_args
         self.task_type: Optional[str] = None
         self.issparse: bool = issparse(self.train_tensors[0])
-        self.input_shape: Tuple[int] = train_tensors[0].shape[1:]
-        self.num_classes: Optional[int] = None
-        if len(train_tensors) == 2 and train_tensors[1] is not None:
+        self.input_shape: Tuple[int] = self.train_tensors[0].shape[1:]
+
+        if len(self.train_tensors) == 2 and self.train_tensors[1] is not None:
             self.output_type: str = type_of_target(self.train_tensors[1])
-            self.output_shape: int = train_tensors[1].shape[1] if train_tensors[1].shape == 2 else 1
+
+            if STRING_TO_OUTPUT_TYPES[self.output_type] in CLASSIFICATION_OUTPUTS:
+                self.output_shape = len(np.unique(self.train_tensors[1]))
+            else:
+                self.output_shape = self.train_tensors[1].shape[-1] if self.train_tensors[1].ndim > 1 else 1

         # TODO: Look for a criteria to define small enough to preprocess
         self.is_small_preprocess = True
@@ -368,8 +373,7 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
             'output_type': self.output_type,
             'issparse': self.issparse,
             'input_shape': self.input_shape,
-            'output_shape': self.output_shape,
-            'num_classes': self.num_classes,
+            'output_shape': self.output_shape
         })
         return dataset_properties
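Restated outside the class, the new output_shape rule amounts to the following sketch (simplified: the real code maps self.output_type through STRING_TO_OUTPUT_TYPES and checks membership in CLASSIFICATION_OUTPUTS from autoPyTorch.constants):

import numpy as np
from sklearn.utils.multiclass import type_of_target

def infer_output_shape(y: np.ndarray) -> int:
    output_type = type_of_target(y)  # e.g. 'binary', 'multiclass', 'continuous'
    if output_type in ('binary', 'multiclass'):
        # Classification: one network output per class.
        return len(np.unique(y))
    # Regression: dimensionality of the target, 1 for a flat vector.
    return y.shape[-1] if y.ndim > 1 else 1

print(infer_output_shape(np.array([0, 1, 2, 1])))     # 3
print(infer_output_shape(np.array([0.5, 1.2, 3.3])))  # 1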

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 3 additions & 2 deletions
@@ -12,7 +12,7 @@

 from sklearn.base import BaseEstimator
 from sklearn.dummy import DummyClassifier, DummyRegressor
-from sklearn.ensemble import VotingClassifier, VotingRegressor
+from sklearn.ensemble import VotingClassifier

 from smac.tae import StatusType

@@ -32,6 +32,7 @@
 from autoPyTorch.datasets.base_dataset import BaseDataset
 from autoPyTorch.datasets.tabular_dataset import TabularDataset
 from autoPyTorch.evaluation.utils import (
+    VotingRegressorWrapper,
     convert_multioutput_multiclass_to_multilabel
 )
 from autoPyTorch.pipeline.base_pipeline import BasePipeline
@@ -513,7 +514,7 @@ def file_output(
             if self.task_type in CLASSIFICATION_TASKS:
                 pipelines = VotingClassifier(estimators=None, voting='soft', )
             else:
-                pipelines = VotingRegressor(estimators=None)
+                pipelines = VotingRegressorWrapper(estimators=None)
             pipelines.estimators_ = self.pipelines
         else:
             pipelines = None
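VotingRegressorWrapper itself lives in autoPyTorch/evaluation/utils.py and is not part of this diff. A plausible minimal sketch of what such a wrapper needs to do, assuming the network pipelines can return (n_samples, 1)-shaped predictions that must be flattened before sklearn averages them:

import numpy as np
from sklearn.ensemble import VotingRegressor

class VotingRegressorWrapper(VotingRegressor):
    # Sketch: identical to VotingRegressor except each estimator's
    # prediction is raveled to 1-d before the ensemble mean is taken.
    def _predict(self, X):
        return np.asarray([est.predict(X).ravel() for est in self.estimators_]).T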

autoPyTorch/evaluation/train_evaluator.py

Lines changed: 1 addition & 0 deletions
@@ -297,6 +297,7 @@ def _predict(self, pipeline: BaseEstimator,
                                                self.y_valid)
         else:
             valid_pred = None
+
         if self.X_test is not None:
             test_pred = self.predict_function(self.X_test, pipeline,
                                               self.y_train[train_indices])
