Forecasting docs #442

Closed · wants to merge 7 commits into from

2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -33,7 +33,7 @@ jobs:

- name: Install dependencies
run: |
pip install -e .[docs,examples]
pip install -e .[docs,examples,forecasting]

- name: Make docs
run: |
2 changes: 1 addition & 1 deletion .github/workflows/long_regression_test.yml
@@ -30,7 +30,7 @@ jobs:
- name: Install test dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[test]
pip install -e .[forecasting,test]

- name: Run tests
run: |
4 changes: 2 additions & 2 deletions .github/workflows/pytest.yml
@@ -89,7 +89,7 @@ jobs:
run: |
git submodule update --init --recursive
python -m pip install --upgrade pip
pip install -e .[test]
pip install -e .[forecasting,test]

- name: Dist install
if: matrix.kind == 'dist'
@@ -98,7 +98,7 @@

python setup.py sdist
last_dist=$(ls -t dist/autoPyTorch-*.tar.gz | head -n 1)
pip install $last_dist[test]
pip install $last_dist[forecasting,test]

- name: Store repository status
id: status-before
99 changes: 96 additions & 3 deletions README.md
@@ -4,8 +4,9 @@ Copyright (C) 2021 [AutoML Groups Freiburg and Hannover](http://www.automl.org/

While early AutoML frameworks focused on optimizing traditional ML pipelines and their hyperparameters, another trend in AutoML is to focus on neural architecture search. To bring the best of these two worlds together, we developed **Auto-PyTorch**, which jointly and robustly optimizes the network architecture and the training hyperparameters to enable fully automated deep learning (AutoDL).

Auto-PyTorch is mainly developed to support tabular data (classification, regression).
Auto-PyTorch is mainly developed to support tabular data (classification, regression) and time series data (forecasting).
The newest features in Auto-PyTorch for tabular data are described in the paper ["Auto-PyTorch Tabular: Multi-Fidelity MetaLearning for Efficient and Robust AutoDL"](https://arxiv.org/abs/2006.13799) (see below for bibtex ref).
Details about Auto-PyTorch for multi-horizon time series forecasting tasks can be found in the paper ["Efficient Automated Deep Learning for Time Series Forecasting"](https://arxiv.org/abs/2205.05511) (also see below for bibtex ref).

Also, find the documentation [here](https://automl.github.io/Auto-PyTorch/master).

@@ -27,7 +28,9 @@ In other words, we evaluate the portfolio on a provided data as initial configur
Then the API starts the following procedures:
1. **Validate input data**: Process each data type, e.g. encode categorical data, so that Auto-PyTorch can handle it.
2. **Create dataset**: Create a dataset that can be handled in this API with a choice of cross validation or holdout splits.
3. **Evaluate baselines** *1: Train each algorithm in the predefined pool with a fixed hyperparameter configuration and dummy model from `sklearn.dummy` that represents the worst possible performance.
3. **Evaluate baselines**
* ***Tabular dataset*** *1: Train each algorithm in the predefined pool with a fixed hyperparameter configuration, plus a dummy model from `sklearn.dummy` that represents the worst possible performance.
* ***Time Series Forecasting dataset***: Train a dummy predictor that repeats the last observed value in each series (see the sketch below).
4. **Search by [SMAC](https://github.com/automl/SMAC3)**:\
a. Determine budget and cut-off rules by [Hyperband](https://jmlr.org/papers/volume18/16-558/16-558.pdf)\
b. Sample a pipeline hyperparameter configuration *2 by SMAC\
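
For illustration only (not part of this diff): a minimal sketch of the naive last-value baseline from step 3, assuming each series is a 1-D NumPy array. The helper name is hypothetical.

```py
import numpy as np

def naive_last_value_forecast(y_train: np.ndarray, horizon: int) -> np.ndarray:
    """Repeat the last observed value of the series for `horizon` steps ahead."""
    return np.full(horizon, y_train[-1])

# e.g. naive_last_value_forecast(np.array([1.0, 2.0, 3.0]), 2) -> array([3., 3.])
```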
@@ -50,6 +53,14 @@ pip install autoPyTorch

```

Auto-PyTorch for Time Series Forecasting requires additional dependencies:

```sh

pip install autoPyTorch[forecasting]

```
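
If the extra installed correctly, the forecasting API used in the example below should be importable; a quick sanity check:

```py
# Should succeed after `pip install autoPyTorch[forecasting]`
from autoPyTorch.api.time_series_forecasting import TimeSeriesForecastingTask
```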

### Manual Installation

We recommend using Anaconda for development, as follows:
@@ -70,6 +81,20 @@ python setup.py install

```

Similarly, to install all the dependencies for Auto-PyTorch-TimeSeriesForecasting:


```sh

git submodule update --init --recursive

conda create -n auto-pytorch python=3.8
conda activate auto-pytorch
conda install swig
pip install -e .[forecasting]

```

## Examples

In a nutshell:
@@ -105,6 +130,63 @@ score = api.score(y_pred, y_test)
print("Accuracy score", score)
```

For time series forecasting tasks:
```py

from autoPyTorch.api.time_series_forecasting import TimeSeriesForecastingTask

# data and metric imports
from sktime.datasets import load_longley
targets, features = load_longley()

# define the forecasting horizon
forecasting_horizon = 3

# each series represents one element in the list
# we take the last forecasting_horizon points as test targets; the values before them are the training targets
# normally, the values to be forecast should directly follow the training set
y_train = [targets[: -forecasting_horizon]]
y_test = [targets[-forecasting_horizon:]]

# same for features. For univariate models, X_train and X_test can be omitted
X_train = [features[: -forecasting_horizon]]
# Here X_test holds the 'known future features', i.e. features whose future values are known in advance.
# Unknown features can be filled with NaN or zeros (they will not be used by the networks).
# If no features are known beforehand, X_test can be omitted as well
known_future_features = list(features.columns)
X_test = [features[-forecasting_horizon:]]

start_times = [targets.index.to_timestamp()[0]]
freq = '1Y'

# initialise Auto-PyTorch api
api = TimeSeriesForecastingTask()

# Search for an ensemble of machine learning algorithms
api.search(
X_train=X_train,
y_train=y_train,
X_test=X_test,
optimize_metric='mean_MAPE_forecasting',
n_prediction_steps=forecasting_horizon,
memory_limit=16 * 1024,  # currently, forecasting models request more memory than they actually need
freq=freq,
start_times=start_times,
func_eval_time_limit_secs=50,
total_walltime_limit=60,
min_num_test_instances=1000,  # build a proxy validation set; only applied to tasks with more than 1000 series
known_future_features=known_future_features,
)

# the dataset object can directly generate test sequences for new data
test_sets = api.dataset.generate_test_seqs()

# Calculate the test score
y_pred = api.predict(test_sets)
score = api.score(y_pred, y_test)
print("Forecasting score", score)
```
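
For reference: `mean_MAPE_forecasting` averages a MAPE-style error across series. A minimal sketch of that aggregation, assuming equal weighting per series (the exact definition inside Auto-PyTorch may differ, and `mean_mape` is a hypothetical helper):

```py
import numpy as np

def mean_mape(y_true_per_series, y_pred_per_series, eps=1e-10):
    """Average the MAPE of each series; eps guards against division by zero."""
    mapes = [
        float(np.mean(np.abs((np.asarray(yt) - np.asarray(yp))
                             / np.maximum(np.abs(np.asarray(yt)), eps))))
        for yt, yp in zip(y_true_per_series, y_pred_per_series)
    ]
    return float(np.mean(mapes))
```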

For more examples, including customising the search space, parallelising the code, etc., check out the `examples` folder

```sh
@@ -163,6 +245,17 @@ Please refer to the branch `TPAMI.2021.3067763` to reproduce the paper *Auto-PyT
}
```

```bibtex
@inproceedings{deng-ecml22,
author = {Difan Deng and Florian Karl and Frank Hutter and Bernd Bischl and Marius Lindauer},
title = {Efficient Automated Deep Learning for Time Series Forecasting},
year = {2022},
booktitle = {Machine Learning and Knowledge Discovery in Databases. Research Track
- European Conference, {ECML} {PKDD} 2022},
url = {https://doi.org/10.48550/arXiv.2205.05511},
}
```

## Contact

Auto-PyTorch is developed by the [AutoML Group of the University of Freiburg](http://www.automl.org/).
Auto-PyTorch is developed by the [AutoML Groups of the University of Freiburg and Hannover](http://www.automl.org/).
79 changes: 55 additions & 24 deletions autoPyTorch/api/base_task.py
@@ -34,9 +34,12 @@
from autoPyTorch import metrics
from autoPyTorch.automl_common.common.utils.backend import Backend, create
from autoPyTorch.constants import (
FORECASTING_BUDGET_TYPE,
FORECASTING_TASKS,
REGRESSION_TASKS,
STRING_TO_OUTPUT_TYPES,
STRING_TO_TASK_TYPES,
TIMESERIES_FORECASTING,
)
from autoPyTorch.data.base_validator import BaseInputValidator
from autoPyTorch.data.utils import DatasetCompressionSpec
@@ -77,7 +80,8 @@ def _pipeline_predict(pipeline: BasePipeline,
X: Union[np.ndarray, pd.DataFrame],
batch_size: int,
logger: PicklableClientLogger,
task: int) -> np.ndarray:
task: int,
task_type: str = "") -> np.ndarray:
@typing.no_type_check
def send_warnings_to_log(
message, category, filename, lineno, file=None, line=None):
@@ -87,7 +91,7 @@ def send_warnings_to_log(
X_ = X.copy()
with warnings.catch_warnings():
warnings.showwarning = send_warnings_to_log
if task in REGRESSION_TASKS:
if task in REGRESSION_TASKS or task in FORECASTING_TASKS:
# Voting regressor does not support batch size
prediction = pipeline.predict(X_)
else:
@@ -101,13 +105,13 @@ def send_warnings_to_log(
prediction,
np.sum(prediction, axis=1)
))

if len(prediction.shape) < 1 or len(X_.shape) < 1 or \
X_.shape[0] < 1 or prediction.shape[0] != X_.shape[0]:
logger.warning(
"Prediction shape for model %s is %s while X_.shape is %s",
pipeline, str(prediction.shape), str(X_.shape)
)
if STRING_TO_TASK_TYPES.get(task_type, -1) != TIMESERIES_FORECASTING:
if len(prediction.shape) < 1 or len(X_.shape) < 1 or \
X_.shape[0] < 1 or prediction.shape[0] != X_.shape[0]:
logger.warning(
"Prediction shape for model %s is %s while X_.shape is %s",
pipeline, str(prediction.shape), str(X_.shape)
)
return prediction


@@ -218,6 +222,8 @@ def __init__(
self.search_space: Optional[ConfigurationSpace] = None
self._dataset_requirements: Optional[List[FitRequirement]] = None
self._metric: Optional[autoPyTorchMetric] = None
self._metrics_kwargs: Dict = {}

self._scoring_functions: Optional[List[autoPyTorchMetric]] = None
self._logger: Optional[PicklableClientLogger] = None
self.dataset_name: Optional[str] = None
@@ -737,7 +743,7 @@ def _do_dummy_prediction(self) -> None:
stats=stats,
memory_limit=memory_limit,
disable_file_output=self._disable_file_output,
all_supported_metrics=self._all_supported_metrics
all_supported_metrics=self._all_supported_metrics,
)

status, _, _, additional_info = ta.run(num_run, cutoff=self._time_for_task)
@@ -822,7 +828,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
stats=stats,
memory_limit=memory_limit,
disable_file_output=self._disable_file_output,
all_supported_metrics=self._all_supported_metrics
all_supported_metrics=self._all_supported_metrics,
)
dask_futures.append([
classifier,
@@ -906,8 +912,8 @@ def _search(
optimize_metric: str,
dataset: BaseDataset,
budget_type: str = 'epochs',
min_budget: int = 5,
max_budget: int = 50,
min_budget: Union[int, float] = 5,
max_budget: Union[int, float] = 50,
total_walltime_limit: int = 100,
func_eval_time_limit_secs: Optional[int] = None,
enable_traditional_pipeline: bool = True,
@@ -920,7 +926,8 @@
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
load_models: bool = True,
portfolio_selection: Optional[str] = None,
dask_client: Optional[dask.distributed.Client] = None
dask_client: Optional[dask.distributed.Client] = None,
**kwargs: Any
) -> 'BaseTask':
"""
Search for the best pipeline configuration for the given dataset.
@@ -1048,7 +1055,14 @@ def _search(
Additionally, the keyword 'greedy' is supported,
which would use the default portfolio from
`AutoPyTorch Tabular <https://arxiv.org/abs/2006.13799>`_

kwargs: Any
    Additional arguments that are customised by specific tasks.
    For instance, forecasting tasks require:
    min_num_test_instances (int): minimal number of instances used to initialize a proxy validation set
    suggested_init_models (List[str]): a set of initial models suggested by the users; their
        hyperparameters are determined by the default configurations
    custom_init_setting_path (str): the path to the initial hyperparameter configurations set by
        the users
Returns:
self

Expand Down Expand Up @@ -1110,7 +1124,10 @@ def _search(
self.search_space = self.get_search_space(dataset)

# Incorporate budget to pipeline config
if budget_type not in ('epochs', 'runtime'):
if budget_type not in ('epochs', 'runtime') and not (
    budget_type in FORECASTING_BUDGET_TYPE
    and STRING_TO_TASK_TYPES[self.task_type] == TIMESERIES_FORECASTING
):
    raise ValueError("Budget type must be one of ('epochs', 'runtime')"
                     f" yet {budget_type} was provided")
self.pipeline_options['budget_type'] = budget_type
@@ -1216,6 +1233,7 @@ def _search(
precision=precision,
logger_port=self._logger_port,
pynisher_context=self._multiprocessing_context,
metrics_kwargs=self._metrics_kwargs,
)
self._stopwatch.stop_task(ensemble_task_name)

@@ -1229,7 +1247,6 @@
if time_left_for_smac <= 0:
self._logger.warning(" Not starting SMAC because there is no time left")
else:

_proc_smac = AutoMLSMBO(
config_space=self.search_space,
dataset_name=str(dataset.dataset_name),
@@ -1259,6 +1276,8 @@
search_space_updates=self.search_space_updates,
portfolio_selection=portfolio_selection,
pynisher_context=self._multiprocessing_context,
task_type=self.task_type,
**kwargs,
)
try:
run_history, self._results_manager.trajectory, budget_type = \
@@ -1323,19 +1342,30 @@ def _get_fit_dictionary(
dataset: BaseDataset,
split_id: int = 0
) -> Dict[str, Any]:
X_test = dataset.test_tensors[0].copy() if dataset.test_tensors is not None else None
y_test = dataset.test_tensors[1].copy() if dataset.test_tensors is not None else None
if dataset.test_tensors is not None:
X_test = dataset.test_tensors[0].copy() if dataset.test_tensors[0] is not None else None
y_test = dataset.test_tensors[1].copy() if dataset.test_tensors[1] is not None else None
else:
X_test = None
y_test = None

X_train = dataset.train_tensors[0].copy() if dataset.train_tensors[0] is not None else None
y_train = dataset.train_tensors[1].copy()
X: Dict[str, Any] = dict({'dataset_properties': dataset_properties,
'backend': self._backend,
'X_train': dataset.train_tensors[0].copy(),
'y_train': dataset.train_tensors[1].copy(),
'X_train': X_train,
'y_train': y_train,
'X_test': X_test,
'y_test': y_test,
'train_indices': dataset.splits[split_id][0],
'val_indices': dataset.splits[split_id][1],
'split_id': split_id,
'num_run': self._backend.get_next_num_run(),
})
if STRING_TO_TASK_TYPES[self.task_type] == TIMESERIES_FORECASTING:
warnings.warn("Currently, time series forecasting tasks do not allow computing metrics "
              "during training; `metrics_during_training` will automatically be set to False.")
self.pipeline_options["metrics_during_training"] = False
X.update(self.pipeline_options)
return X

@@ -1398,7 +1428,7 @@ def refit(
# could alleviate the problem in algorithms that depend on
# the ordering of the data.
X = self._get_fit_dictionary(
dataset_properties=dataset_properties,
dataset_properties=copy.copy(dataset_properties),
dataset=dataset,
split_id=split_id)
fit_and_suppress_warnings(self._logger, model, X, y=None)
@@ -1630,7 +1660,7 @@ def fit_pipeline(
exclude=exclude_components,
search_space_updates=search_space_updates,
pipeline_config=pipeline_options,
pynisher_context=self._multiprocessing_context
pynisher_context=self._multiprocessing_context,
)

run_info, run_value = tae.run_wrapper(
@@ -1722,7 +1752,8 @@ def predict(

all_predictions = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(_pipeline_predict)(
models[identifier], X_test, batch_size, self._logger, STRING_TO_TASK_TYPES[self.task_type]
models[identifier], X_test, batch_size, self._logger, STRING_TO_TASK_TYPES[self.task_type],
self.task_type
)
for identifier in self.ensemble_.get_selected_model_identifiers()
)