Time series forecasting #434

Merged
merged 357 commits into from
Jun 28, 2022
Changes from 1 commit
Commits
357 commits
f193423
new target scaler, allow NoNorm for MLP Encpder
dengdifan Dec 22, 2021
752a58f
allow sampling full sequences
dengdifan Dec 22, 2021
2ab286b
integrate SeqBuilder to SequenceCollector
dengdifan Dec 22, 2021
d90d630
restore SequenceBuilder to reduce memory usage
dengdifan Dec 22, 2021
adcf8a0
move scaler to network
dengdifan Dec 22, 2021
c3da078
lag sequence
dengdifan Dec 22, 2021
5d67973
merge encoder and decoder as a single pipeline
dengdifan Dec 30, 2021
45d2078
faster lag_seq builder
dengdifan Jan 3, 2022
e4c5358
maint
dengdifan Jan 4, 2022
a0a01b3
new init, faster DeepAR inference in trainer
dengdifan Jan 5, 2022
27e0eb0
more losses types
dengdifan Jan 6, 2022
3710cb2
maint
dengdifan Jan 6, 2022
70d82d9
new Transformer models, allow RNN to do deepAR inference
dengdifan Jan 9, 2022
aa0221a
maint
dengdifan Jan 12, 2022
0a8b5f2
maint
dengdifan Jan 12, 2022
7a7b68d
maint
dengdifan Jan 13, 2022
eeeda02
maint
dengdifan Jan 13, 2022
738f18d
reduced search space for Transformer
dengdifan Jan 13, 2022
353b5c5
reduced init design
dengdifan Jan 13, 2022
e8db57b
maint
dengdifan Jan 13, 2022
8dec08c
maint
dengdifan Jan 14, 2022
7008cec
maint
dengdifan Jan 14, 2022
c7d401e
maint
dengdifan Jan 14, 2022
df15a2b
faster forecasting
dengdifan Jan 16, 2022
ca6a47d
maint
dengdifan Jan 16, 2022
5c9b10c
allow singel fidelity
dengdifan Jan 16, 2022
64580f7
maint
dengdifan Jan 17, 2022
df69c79
fix budget num_seq
dengdifan Jan 17, 2022
91acd5b
faster sampler and lagger
dengdifan Jan 17, 2022
fa96b27
maint
dengdifan Jan 17, 2022
995684a
maint
dengdifan Jan 17, 2022
d3a0e31
maint deepAR
dengdifan Jan 18, 2022
d8b3892
maint
dengdifan Jan 18, 2022
5e9fbae
maint
dengdifan Jan 19, 2022
3eb197e
cross validation
dengdifan Jan 19, 2022
6f4fdf1
allow holdout for smaller datasets
dengdifan Jan 19, 2022
8de8e50
smac4ac to smac4hpo
dengdifan Jan 19, 2022
54fbe59
maint
dengdifan Jan 21, 2022
918776f
maint
dengdifan Jan 21, 2022
8570951
allow to change decoder search space
dengdifan Jan 21, 2022
90edcfb
more resampling strategy, more options for MLP
dengdifan Jan 24, 2022
df7f568
reduced NBEATS
dengdifan Jan 24, 2022
6a22e4d
subsampler for val loader
dengdifan Jan 25, 2022
01ee1f1
rng for dataloader sampler
dengdifan Jan 25, 2022
5dcdc4e
maint
dengdifan Jan 26, 2022
f32a12b
remove generator as it cannot be pickled
dengdifan Jan 26, 2022
344e7df
allow lower fidelity to evaluate less test instances
dengdifan Jan 28, 2022
94891f2
fix dummy forecastro isues
dengdifan Jan 30, 2022
8fc51b1
maint
dengdifan Jan 30, 2022
0b34683
add gluonts as requirement
dengdifan Jan 30, 2022
f23840b
more data for val set for larger dataset
dengdifan Feb 1, 2022
dc84904
maint
dengdifan Feb 2, 2022
6466d11
Merge branch 'refactor_development_time_series' of https://github.com…
dengdifan Feb 2, 2022
6981553
maint
dengdifan Feb 3, 2022
3fda94a
fix nbeats decoder
Feb 14, 2022
d95e230
new dataset interface
dengdifan Feb 16, 2022
d185609
Merge branch 'refactor_development_time_series' of https://github.com…
dengdifan Feb 16, 2022
d5459fa
resolve conflict
dengdifan Feb 16, 2022
510cc5a
maint
dengdifan Feb 16, 2022
3806fe2
allow encoder to receive input from different sources
dengdifan Feb 16, 2022
9251bbc
multi blocks hp design
dengdifan Feb 18, 2022
5617db6
maint
dengdifan Feb 20, 2022
d04cb04
correct hp updates
dengdifan Feb 20, 2022
7881bb5
first trial on nested conjunction
dengdifan Feb 21, 2022
d7bff6e
maint
dengdifan Feb 21, 2022
2153bc2
fit for deep AR model (needs to be reverted when the issue in ConfigS…
dengdifan Feb 21, 2022
b2063e7
adjust backbones to fit new structure
dengdifan Feb 23, 2022
59cee13
further API changes
dengdifan Feb 28, 2022
b2b5580
tft temporal fusion decoder
dengdifan Feb 28, 2022
57461b9
construct network
dengdifan Mar 2, 2022
20eb852
cells for networks
dengdifan Mar 2, 2022
f5cede7
forecasting backbones
dengdifan Mar 4, 2022
50c559e
maint
dengdifan Mar 4, 2022
2dd0b11
maint
dengdifan Mar 6, 2022
0f0dbf0
move tft layer to backbone
dengdifan Mar 7, 2022
9e68629
maint
dengdifan Mar 7, 2022
ed99ba1
quantile loss
dengdifan Mar 7, 2022
45535ba
maint
dengdifan Mar 8, 2022
9fac9fe
maint
dengdifan Mar 8, 2022
75570c2
maint
dengdifan Mar 8, 2022
2f954cd
maint
dengdifan Mar 8, 2022
31f8ddc
maint
dengdifan Mar 8, 2022
7f4911e
maint
dengdifan Mar 8, 2022
2e31fdb
forecasting init configs
dengdifan Mar 8, 2022
125921c
add forbidden
dengdifan Mar 9, 2022
8d704d1
maint
dengdifan Mar 10, 2022
e646672
maint
dengdifan Mar 10, 2022
a2ad3fe
maint
dengdifan Mar 10, 2022
200691c
remove shift data
dengdifan Mar 11, 2022
538f24e
maint
dengdifan Mar 11, 2022
12ccf4b
maint
dengdifan Mar 11, 2022
4d6853d
copy dataset_properties for each refit iteration
dengdifan Mar 11, 2022
34d556a
maint and new init
dengdifan Mar 14, 2022
37501ef
Tft forecating with features (#6)
dengdifan Mar 16, 2022
5746541
fix loss computation in QuantileLoss
dengdifan Mar 16, 2022
b1fbece
fixed scaler computation
dengdifan Mar 18, 2022
683ccf5
maint
dengdifan Mar 19, 2022
95d2ab5
fix dataset
Mar 22, 2022
baaf34f
adjust window_size to seasonality
Mar 22, 2022
897cd74
maint scaling
dengdifan Mar 23, 2022
a09ddbb
fix uncorrect Seq2Seq scaling
dengdifan Mar 23, 2022
c1dda0a
fix sampling for seq2seq
dengdifan Mar 25, 2022
49ee49c
maint
dengdifan Mar 25, 2022
dc97df2
fix scaling in NBEATS
dengdifan Mar 25, 2022
399572c
move time feature computation to dataset
dengdifan Mar 28, 2022
7154308
maint
dengdifan Mar 30, 2022
1ba08fe
fix feature computation
dengdifan Mar 31, 2022
04a69d8
maint
dengdifan Mar 31, 2022
471db34
multi-variant feature validator
dengdifan Apr 13, 2022
16cf754
resolve conflicts
dengdifan Apr 13, 2022
cc77b51
maint
dengdifan Apr 13, 2022
fe6fb1f
validator for multi-variant series
dengdifan Apr 13, 2022
9264f89
feature validator
dengdifan Apr 14, 2022
aa3f7a6
multi-variant datasets
dengdifan Apr 14, 2022
974f8ff
observed targets
dengdifan Apr 14, 2022
37dd821
stucture adjustment
dengdifan Apr 20, 2022
1a6e19d
refactory ts tasks and preprocessing
dengdifan Apr 22, 2022
075c6e6
allow nan in targets
dengdifan Apr 22, 2022
2487117
preprocessing for time series
dengdifan Apr 22, 2022
86e4e3c
maint
dengdifan Apr 25, 2022
2c9944c
forecasting pipeline
dengdifan Apr 25, 2022
7eb5139
maint
dengdifan Apr 26, 2022
22fc0bc
embedding and maint
dengdifan Apr 26, 2022
1759fdf
move targets to the tail of the features
dengdifan Apr 26, 2022
9652c80
maint
dengdifan Apr 26, 2022
1d89636
static features
dengdifan Apr 27, 2022
282d63b
adjsut scaler to static features
dengdifan Apr 27, 2022
fb8b805
remove static features from forward dict
dengdifan Apr 27, 2022
533f12d
test transform
dengdifan Apr 27, 2022
f8be97c
maint
dengdifan Apr 28, 2022
e8c9071
test sets
dengdifan Apr 28, 2022
2779015
adjust dataset to allow future known features
dengdifan Apr 29, 2022
1a1fe68
maint
dengdifan Apr 29, 2022
f4ad355
maint
dengdifan Apr 29, 2022
79ef7a7
flake8
dengdifan Apr 29, 2022
88977e0
synchronise with development
dengdifan May 2, 2022
b269ff8
recover timeseries
dengdifan May 2, 2022
31f9e43
maint
dengdifan May 2, 2022
67ea836
maint
dengdifan May 2, 2022
80b8ac2
limit memory usage tae
dengdifan May 2, 2022
d01e2a7
revert test api
dengdifan May 2, 2022
3be7be9
test for targets
dengdifan May 3, 2022
77dcb7c
not allow sparse forecasting target
dengdifan May 3, 2022
6932199
test for data validator
dengdifan May 4, 2022
ee97108
test for validations
dengdifan May 5, 2022
b7f51f2
test on TimeSeriesSequence
dengdifan May 5, 2022
08bfe18
maint
dengdifan May 5, 2022
478ad68
test for resampling
dengdifan May 6, 2022
1986593
test for dataset 1
dengdifan May 7, 2022
112c876
test for datasets
dengdifan May 8, 2022
235e310
test on tae
dengdifan May 9, 2022
9d8dd0b
maint
dengdifan May 9, 2022
dc4b602
all evaluator to evalaute test sets
dengdifan May 10, 2022
e8cf8cb
tests on losses
dengdifan May 10, 2022
e5b1c47
test for metrics
dengdifan May 10, 2022
3f47489
forecasting preprocessing
dengdifan May 10, 2022
835055d
maint
dengdifan May 11, 2022
ef9e44e
finish test for preprocessing
dengdifan May 11, 2022
21b3958
test for data loader
dengdifan May 12, 2022
101ddbc
tests for dataloader
dengdifan May 13, 2022
7318086
maint
dengdifan May 13, 2022
cf2c982
test for target scaling 1
dengdifan May 13, 2022
8b7ef61
test for target scaer
dengdifan May 15, 2022
1025b93
test for training loss
dengdifan May 15, 2022
6f68633
maint
dengdifan May 16, 2022
570408d
test for network backbone
dengdifan May 16, 2022
7d42007
test for backbone base
dengdifan May 17, 2022
2033075
test for flat encoder
dengdifan May 17, 2022
c6e2239
test for seq encoder
dengdifan May 17, 2022
727e48e
test for seqencoder
dengdifan May 18, 2022
23dde67
maint
dengdifan May 18, 2022
4d9fe30
test for recurrent decoders
dengdifan May 19, 2022
eb5a7ec
test for network
dengdifan May 19, 2022
0ea372e
maint
dengdifan May 19, 2022
1b7ebbe
test for architecture
dengdifan May 20, 2022
f055fd5
test for pipelines
dengdifan May 20, 2022
ccab50e
fixed sampler
dengdifan May 21, 2022
54acaa6
maint sampler
dengdifan May 21, 2022
da6e92d
resolve conflict between embedding and net encoder
dengdifan May 21, 2022
fba012c
fix scaling
dengdifan May 21, 2022
2ed1197
allow transform for test dataloader
dengdifan May 21, 2022
95eb783
maint dataloader
dengdifan May 21, 2022
8035221
fix updates
dengdifan May 22, 2022
f3cb2de
fix dataset
dengdifan May 23, 2022
0af1217
tests on api, initial design on multi-variant
dengdifan May 24, 2022
c717fae
maint
dengdifan May 24, 2022
78d7a51
fix dataloader
dengdifan May 24, 2022
fa5cc75
move test with for loop to unittest.subtest
dengdifan May 24, 2022
2d2e039
flake 8 and update requirement
dengdifan May 24, 2022
a1c7930
mypy
dengdifan May 24, 2022
ba96c37
validator for pd dataframe
dengdifan May 27, 2022
cdcdb5a
allow series idx for api
dengdifan May 27, 2022
43671dd
maint
dengdifan May 30, 2022
806afb3
examples for forecasting
dengdifan May 30, 2022
bc80bf1
fix mypy
dengdifan May 30, 2022
c584a58
properly memory limitation for forecasting example
dengdifan May 30, 2022
0e37178
fix pre-commit
dengdifan May 30, 2022
1cf31b2
maint dataloader
dengdifan May 31, 2022
a8fa53c
remove unused auto-regressive arguments
dengdifan May 31, 2022
a8bd54d
fix pre-commit
dengdifan May 31, 2022
609ccf1
maint
dengdifan May 31, 2022
168b7cf
maint mypy
dengdifan May 31, 2022
88c2354
mypy!!!
dengdifan May 31, 2022
374cc1d
pre-commit
dengdifan May 31, 2022
4898ca5
mypyyyyyyyyyyyyyyyyyyyyyyyy
dengdifan May 31, 2022
694eebb
maint
dengdifan Jun 13, 2022
abd3900
move forcasting requirements to extras_require
dengdifan Jun 13, 2022
776aa84
bring eval_test to tae
dengdifan Jun 14, 2022
f70e2b3
make rh2epm consistent with SMAC4HPO
dengdifan Jun 14, 2022
50f6f18
remove smac4ac from smbo
dengdifan Jun 14, 2022
2663ad9
revert changes in network
dengdifan Jun 14, 2022
58eeb0c
revert changes in trainer
dengdifan Jun 14, 2022
b86908f
revert format changes
dengdifan Jun 14, 2022
68d8a25
move constant_forecasting to constatn
dengdifan Jun 14, 2022
dac5cdd
additional annotate for base pipeline
dengdifan Jun 14, 2022
7f2d394
move forecasting check to tae
dengdifan Jun 14, 2022
e43d70a
maint time series refit dataset
dengdifan Jun 14, 2022
dc48b9d
fix test
dengdifan Jun 14, 2022
1e7253a
workflow for extra requirements
dengdifan Jun 14, 2022
83e2469
docs for time series dataset
dengdifan Jun 14, 2022
1671992
fix pre-commit
dengdifan Jun 14, 2022
97d3835
docs for dataset
dengdifan Jun 14, 2022
889c5e9
maint docstring
dengdifan Jun 14, 2022
f68dc18
merge target scaler to one file
dengdifan Jun 14, 2022
dc4f510
fix forecasting init cfgs
dengdifan Jun 14, 2022
951ef4e
remove redudant pipeline configs
dengdifan Jun 14, 2022
10f0c83
maint
dengdifan Jun 14, 2022
8574c6f
SMAC4HPO instead of SMAC4AC in smbo (will be reverted further if stud…
dengdifan Jun 15, 2022
86e39bc
fixed docstrign for RNN and Transformer Decoder
dengdifan Jun 15, 2022
21fbcb2
uniformed docstrings for smbo and base task
dengdifan Jun 15, 2022
ee66c25
correct encoder to decoder in decoder.init
dengdifan Jun 15, 2022
877a124
fix doc strings
dengdifan Jun 15, 2022
1d3a74e
add license and docstrings for NBEATS heads
dengdifan Jun 16, 2022
2516859
allow memory limit to be None
dengdifan Jun 16, 2022
fe5e587
relax test load for forecasting
dengdifan Jun 16, 2022
2c6f66f
fix docs
dengdifan Jun 16, 2022
bb7f5c5
fix pre-commit
dengdifan Jun 16, 2022
9d728b5
make test compatible with py37
dengdifan Jun 17, 2022
a331093
maint docstring
dengdifan Jun 17, 2022
8a5a91b
split forecasting_eval_train_function from eval_train_function
dengdifan Jun 17, 2022
acddd22
fix namespace for test_api from train_evaluator to tae
dengdifan Jun 17, 2022
b18ce92
maint test api for forecasting
dengdifan Jun 17, 2022
0700e61
decrease number of ensemble size of test_time_series_forecasting to r…
dengdifan Jun 17, 2022
e4328ee
flatten all the prediction for forecasting pipelines
dengdifan Jun 17, 2022
b6baef1
pre-commit fix
dengdifan Jun 17, 2022
c1de20f
Merge remote-tracking branch 'upstream/development' into time_series_…
dengdifan Jun 20, 2022
0771c8e
fix docstrings and typing
dengdifan Jun 20, 2022
d066fda
maint time series dataset docstrings
dengdifan Jun 22, 2022
f701df3
maint warning message in time_series_forecasting_train_evaluator
dengdifan Jun 22, 2022
5e970f6
fix lines that are overlength
dengdifan Jun 22, 2022
revert format changes
dengdifan committed Jun 14, 2022
commit b86908fa0756d45d1f847ecd7057c1988c73312f
1 change: 0 additions & 1 deletion README.md
@@ -4,7 +4,6 @@ Copyright (C) 2021 [AutoML Groups Freiburg and Hannover](http://www.automl.org/

While early AutoML frameworks focused on optimizing traditional ML pipelines and their hyperparameters, another trend in AutoML is to focus on neural architecture search. To bring the best of these two worlds together, we developed **Auto-PyTorch**, which jointly and robustly optimizes the network architecture and the training hyperparameters to enable fully automated deep learning (AutoDL).


Auto-PyTorch is mainly developed to support tabular data (classification, regression).
The newest features in Auto-PyTorch for tabular data are described in the paper ["Auto-PyTorch Tabular: Multi-Fidelity MetaLearning for Efficient and Robust AutoDL"](https://arxiv.org/abs/2006.13799) (see below for bibtex ref).

243 changes: 120 additions & 123 deletions autoPyTorch/api/base_task.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion autoPyTorch/constants.py
@@ -11,7 +11,6 @@
TABULAR_TASKS = [TABULAR_CLASSIFICATION, TABULAR_REGRESSION]
IMAGE_TASKS = [IMAGE_CLASSIFICATION, IMAGE_REGRESSION]
TIMESERIES_TASKS = [TIMESERIES_FORECASTING]

TASK_TYPES = REGRESSION_TASKS + CLASSIFICATION_TASKS + FORECASTING_TASKS

TASK_TYPES_TO_STRING = \
2 changes: 1 addition & 1 deletion autoPyTorch/constants_forecasting.py
@@ -1,4 +1,4 @@
# The cosntant values for time series forecasting comes from
# The constant values for time series forecasting comes from
# https://github.com/rakshitha123/TSForecasting/blob/master/experiments/deep_learning_experiments.py
# seasonality map, maps a frequency value to a number

32 changes: 17 additions & 15 deletions autoPyTorch/data/tabular_target_validator.py
@@ -17,6 +17,7 @@
from autoPyTorch.data.base_target_validator import BaseTargetValidator, SupportedTargetTypes
from autoPyTorch.utils.common import ispandas


ArrayType = Union[np.ndarray, spmatrix]


@@ -54,9 +55,9 @@ def _modify_regression_target(y: ArrayType, allow_nan: bool = False) -> ArrayTyp

class TabularTargetValidator(BaseTargetValidator):
def _fit(
self,
y_train: SupportedTargetTypes,
y_test: Optional[SupportedTargetTypes] = None,
self,
y_train: SupportedTargetTypes,
y_test: Optional[SupportedTargetTypes] = None,
) -> BaseEstimator:
"""
If dealing with classification, this utility encodes the targets.
@@ -93,10 +94,10 @@ def _fit(
unknown_value=-1)
else:
# We should not reach this if statement as we check for type of targets before
raise ValueError("Multi-dimensional classification is not yet supported. "
"Encoding multidimensional data converts multiple columns "
"to a 1 dimensional encoding. Data involved = {}/{}".format(np.shape(y_train),
self.type_of_target)
raise ValueError(f"Multi-dimensional classification is not yet supported. "
f"Encoding multidimensional data converts multiple columns "
f"to a 1 dimensional encoding. Data involved = "
f"{np.shape(y_train)}/{self.type_of_target}"
)

# Mypy redefinition
@@ -120,8 +121,8 @@ def _fit(
if is_numeric_dtype(y_train.dtype):
self.dtype = y_train.dtype
elif (
hasattr(y_train, 'dtypes')
and is_numeric_dtype(cast(pd.DataFrame, y_train).dtypes[0])
hasattr(y_train, 'dtypes')
and is_numeric_dtype(cast(pd.DataFrame, y_train).dtypes[0])
):
# This case is for pandas array with a single column
y_train = cast(pd.DataFrame, y_train)
@@ -224,12 +225,13 @@ def _check_data(self, y: SupportedTargetTypes) -> None:
y (SupportedTargetTypes):
A set of features whose dimensionality and data type is going to be checked
"""

if not isinstance(y, (np.ndarray, pd.DataFrame,
List, pd.Series)) \
and not issparse(y): # type: ignore[misc]
raise ValueError("AutoPyTorch only supports Numpy arrays, Pandas DataFrames,"
" pd.Series, sparse data and Python Lists as targets, yet, "
"the provided input is of type {}".format(type(y))
raise ValueError(f"AutoPyTorch only supports Numpy arrays, Pandas DataFrames,"
f" pd.Series, sparse data and Python Lists as targets, yet, "
f"the provided input is of type {type(y)}"
)

# Sparse data muss be numerical
@@ -296,7 +298,7 @@ def _check_data(self, y: SupportedTargetTypes) -> None:
# should filter out unsupported types.
)
if self.type_of_target not in supported_output_types:
raise ValueError("Provided targets are not supported by AutoPyTorch. "
"Provided type is {} whereas supported types are {}.".format(self.type_of_target,
supported_output_types)
raise ValueError(f"Provided targets are not supported by AutoPyTorch. "
f"Provided type is {self.type_of_target} "
f"whereas supported types are {supported_output_types}."
)
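The hunks above in `tabular_target_validator.py` replace `str.format` error messages with f-strings. A minimal standalone sketch of that conversion, using hypothetical stand-in values for `np.shape(y_train)` and `self.type_of_target`, shows the two spellings render the identical message:

```python
# Stand-ins for np.shape(y_train) and self.type_of_target in the diff above.
shape = (100, 3)
type_of_target = "multiclass-multioutput"

# Old spelling: adjacent string literals with a trailing .format() call.
old_msg = ("Multi-dimensional classification is not yet supported. "
           "Encoding multidimensional data converts multiple columns "
           "to a 1 dimensional encoding. Data involved = {}/{}".format(
               shape, type_of_target))

# New spelling from the PR: every fragment is an f-string, so each
# placeholder sits next to the text it belongs to.
new_msg = (f"Multi-dimensional classification is not yet supported. "
           f"Encoding multidimensional data converts multiple columns "
           f"to a 1 dimensional encoding. Data involved = "
           f"{shape}/{type_of_target}")

print(old_msg == new_msg)  # the refactor is purely stylistic
```

Note that in the old spelling only the last fragment carries the `.format()` call, so the placeholders are far from their arguments; the f-string form keeps each interpolation local, which is the readability gain the commit is after.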
24 changes: 14 additions & 10 deletions autoPyTorch/data/time_series_forecasting_validator.py
@@ -57,15 +57,19 @@ def fit( # type: ignore[override]
"""
fit the validator with the training data, (optionally) start times and other information
Args:
X_train (Optional[Union[List, pd.DataFrame]]): training features, could be None for "pure" forecasting tasks
y_train (Union[List, pd.DataFrame]), training targets
series_idx (Optional[Union[List[Union[str, int]], str, int]]): which columns of features are applied to
identify the series
X_test (Optional[Union[List, pd.DataFrame]]): test features. For forecasting tasks, test features indicates
known future features after the forecasting timestep\
y_test (Optional[Union[List, pd.DataFrame]]): target in the future
start_times (Optional[List[pd.DatetimeIndex]]): start times on which the first element of each series is
sampled
X_train (Optional[Union[List, pd.DataFrame]]):
training features, could be None for uni-variant forecasting tasks
y_train (Union[List, pd.DataFrame]),
training targets
series_idx (Optional[Union[List[Union[str, int]], str, int]])
which columns of features are applied to identify the series
X_test (Optional[Union[List, pd.DataFrame]]):
test features. For forecasting tasks, test features indicates known future features
after the forecasting timestep
y_test (Optional[Union[List, pd.DataFrame]]):
target in the future
start_times (Optional[List[pd.DatetimeIndex]]):
start times on which the first element of each series is sampled

"""
if series_idx is not None and not isinstance(series_idx, Iterable):
@@ -329,7 +333,7 @@ def join_series(
X: List[Union[pd.DataFrame, np.ndarray]], return_seq_lengths: bool = False
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, List[int]]]:
"""
join the series into one single value
join the series into one single item
"""
num_sequences = len(X)
sequence_lengths = [0] * num_sequences
2 changes: 2 additions & 0 deletions autoPyTorch/datasets/resampling_strategy.py
@@ -194,6 +194,7 @@ def time_series_hold_out_validation(random_state: np.random.RandomState,

@classmethod
def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str, HoldOutFunc]:

holdout_validators = {
holdout_val_type.name: getattr(cls, holdout_val_type.name)
for holdout_val_type in holdout_val_types
@@ -228,6 +229,7 @@ def stratified_k_fold_cross_validation(random_state: np.random.RandomState,
indices: np.ndarray,
**kwargs: Any
) -> List[Tuple[np.ndarray, np.ndarray]]:

shuffle = kwargs.get('shuffle', True)
cv = StratifiedKFold(n_splits=num_splits, shuffle=shuffle,
random_state=random_state if not shuffle else None)
76 changes: 40 additions & 36 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -23,27 +23,36 @@
import autoPyTorch.pipeline.traditional_tabular_classification
import autoPyTorch.pipeline.traditional_tabular_regression
from autoPyTorch.automl_common.common.utils.backend import Backend
from autoPyTorch.constants import (CLASSIFICATION_TASKS, FORECASTING_TASKS,
IMAGE_TASKS, MULTICLASS, REGRESSION_TASKS,
STRING_TO_OUTPUT_TYPES,
STRING_TO_TASK_TYPES, TABULAR_TASKS)
from autoPyTorch.constants import (
CLASSIFICATION_TASKS,
FORECASTING_TASKS,
IMAGE_TASKS,
MULTICLASS,
REGRESSION_TASKS,
STRING_TO_OUTPUT_TYPES,
STRING_TO_TASK_TYPES,
TABULAR_TASKS
)
from autoPyTorch.constants_forecasting import FORECASTING_BUDGET_TYPE
from autoPyTorch.datasets.base_dataset import (BaseDataset,
BaseDatasetPropertiesType)
from autoPyTorch.datasets.base_dataset import (
BaseDataset,
BaseDatasetPropertiesType
)
from autoPyTorch.datasets.time_series_dataset import TimeSeriesSequence
from autoPyTorch.evaluation.utils import (
DisableFileOutputParameters, VotingRegressorWrapper,
convert_multioutput_multiclass_to_multilabel)
DisableFileOutputParameters,
VotingRegressorWrapper,
convert_multioutput_multiclass_to_multilabel
)
from autoPyTorch.pipeline.base_pipeline import BasePipeline
from autoPyTorch.pipeline.components.training.metrics.base import \
autoPyTorchMetric
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.components.training.metrics.utils import (
calculate_loss, get_metrics)
calculate_loss,
get_metrics
)
from autoPyTorch.utils.common import dict_repr, subsampler
from autoPyTorch.utils.hyperparameter_search_space_update import \
HyperparameterSearchSpaceUpdates
from autoPyTorch.utils.logging_ import (PicklableClientLogger,
get_named_client_logger)
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
from autoPyTorch.utils.pipeline import get_dataset_requirements

__all__ = [
@@ -141,7 +150,6 @@ class MyTraditionalTabularRegressionPipeline(BaseEstimator):
An optional dictionary that is passed to the pipeline's steps. It complies
a similar function as the kwargs
"""

def __init__(self, config: str,
dataset_properties: Dict[str, Any],
random_state: Optional[np.random.RandomState] = None,
@@ -185,7 +193,7 @@ def get_pipeline_representation(self) -> Dict[str, str]:

@staticmethod
def get_default_pipeline_options() -> Dict[str, Any]:
return autoPyTorch.pipeline.traditional_tabular_regression. \
return autoPyTorch.pipeline.traditional_tabular_regression.\
TraditionalTabularRegressionPipeline.get_default_pipeline_options()


@@ -448,7 +456,6 @@ class AbstractEvaluator(object):
search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
An object used to fine tune the hyperparameter search space of the pipeline
"""

def __init__(self, backend: Backend,
queue: Queue,
metric: autoPyTorchMetric,
@@ -465,7 +472,7 @@ def __init__(self, backend: Backend,
init_params: Optional[Dict[str, Any]] = None,
logger_port: Optional[int] = None,
all_supported_metrics: bool = True,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
) -> None:

self.starttime = time.time()
Expand Down Expand Up @@ -494,7 +501,6 @@ def __init__(self, backend: Backend,
self.disable_file_output = disable_file_output

self.pipeline_class: Optional[Union[BaseEstimator, BasePipeline]] = None

if self.task_type in REGRESSION_TASKS:
if isinstance(self.configuration, int):
self.pipeline_class = DummyRegressionPipeline
@@ -572,7 +578,7 @@ def __init__(self, backend: Backend,
self.logger.debug("Search space updates :{}".format(self.search_space_updates))

def _init_datamanager_info(
self,
self,
) -> None:
"""
Initialises instance attributes that come from the datamanager.
@@ -619,10 +625,10 @@ def _init_datamanager_info(
del datamanager

def _init_fit_dictionary(
self,
logger_port: int,
pipeline_config: Dict[str, Any],
metrics_dict: Optional[Dict[str, List[str]]] = None,
self,
logger_port: int,
pipeline_config: Dict[str, Any],
metrics_dict: Optional[Dict[str, List[str]]] = None,
) -> None:
"""
Initialises the fit dictionary
@@ -680,7 +686,7 @@ def _init_fit_dictionary(
self.fit_dictionary.pop('runtime', None)
else:
raise ValueError(f"budget type must be `epochs` or `runtime` or {FORECASTING_BUDGET_TYPE} "
f"(Only used in forecasting taskss), but got {self.budget_type}")
f"(Only used by forecasting taskss), but got {self.budget_type}")

def _get_pipeline(self) -> BaseEstimator:
"""
@@ -837,10 +843,10 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
return None

def calculate_auxiliary_losses(
self,
Y_valid_pred: np.ndarray,
Y_test_pred: np.ndarray,
**metric_kwargs: Any
self,
Y_valid_pred: np.ndarray,
Y_test_pred: np.ndarray,
**metric_kwargs: Any
) -> Tuple[Optional[Dict[str, float]], Optional[Dict[str, float]]]:
"""
A helper function to calculate the performance estimate of the
@@ -877,10 +883,10 @@ def calculate_auxiliary_losses(
return validation_loss_dict, test_loss_dict

def file_output(
self,
Y_optimization_pred: np.ndarray,
Y_valid_pred: np.ndarray,
Y_test_pred: np.ndarray
self,
Y_optimization_pred: np.ndarray,
Y_valid_pred: np.ndarray,
Y_test_pred: np.ndarray
) -> Tuple[Optional[float], Dict]:
"""
This method decides what file outputs are written to disk.
@@ -1015,7 +1021,6 @@ def _predict_proba(self, X: np.ndarray, pipeline: BaseEstimator,
(np.ndarray):
The predictions of pipeline for the given features X
"""

@no_type_check
def send_warnings_to_log(message, category, filename, lineno,
file=None, line=None):
@@ -1050,7 +1055,6 @@ def _predict_regression(self, X: np.ndarray, pipeline: BaseEstimator,
(np.ndarray):
The predictions of pipeline for the given features X
"""

@no_type_check
def send_warnings_to_log(message, category, filename, lineno,
file=None, line=None):