Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time series forecasting #434

Merged
merged 357 commits into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
357 commits
Select commit Hold shift + click to select a range
f193423
new target scaler, allow NoNorm for MLP Encpder
dengdifan Dec 22, 2021
752a58f
allow sampling full sequences
dengdifan Dec 22, 2021
2ab286b
integrate SeqBuilder to SequenceCollector
dengdifan Dec 22, 2021
d90d630
restore SequenceBuilder to reduce memory usage
dengdifan Dec 22, 2021
adcf8a0
move scaler to network
dengdifan Dec 22, 2021
c3da078
lag sequence
dengdifan Dec 22, 2021
5d67973
merge encoder and decoder as a single pipeline
dengdifan Dec 30, 2021
45d2078
faster lag_seq builder
dengdifan Jan 3, 2022
e4c5358
maint
dengdifan Jan 4, 2022
a0a01b3
new init, faster DeepAR inference in trainer
dengdifan Jan 5, 2022
27e0eb0
more losses types
dengdifan Jan 6, 2022
3710cb2
maint
dengdifan Jan 6, 2022
70d82d9
new Transformer models, allow RNN to do deepAR inference
dengdifan Jan 9, 2022
aa0221a
maint
dengdifan Jan 12, 2022
0a8b5f2
maint
dengdifan Jan 12, 2022
7a7b68d
maint
dengdifan Jan 13, 2022
eeeda02
maint
dengdifan Jan 13, 2022
738f18d
reduced search space for Transformer
dengdifan Jan 13, 2022
353b5c5
reduced init design
dengdifan Jan 13, 2022
e8db57b
maint
dengdifan Jan 13, 2022
8dec08c
maint
dengdifan Jan 14, 2022
7008cec
maint
dengdifan Jan 14, 2022
c7d401e
maint
dengdifan Jan 14, 2022
df15a2b
faster forecasting
dengdifan Jan 16, 2022
ca6a47d
maint
dengdifan Jan 16, 2022
5c9b10c
allow singel fidelity
dengdifan Jan 16, 2022
64580f7
maint
dengdifan Jan 17, 2022
df69c79
fix budget num_seq
dengdifan Jan 17, 2022
91acd5b
faster sampler and lagger
dengdifan Jan 17, 2022
fa96b27
maint
dengdifan Jan 17, 2022
995684a
maint
dengdifan Jan 17, 2022
d3a0e31
maint deepAR
dengdifan Jan 18, 2022
d8b3892
maint
dengdifan Jan 18, 2022
5e9fbae
maint
dengdifan Jan 19, 2022
3eb197e
cross validation
dengdifan Jan 19, 2022
6f4fdf1
allow holdout for smaller datasets
dengdifan Jan 19, 2022
8de8e50
smac4ac to smac4hpo
dengdifan Jan 19, 2022
54fbe59
maint
dengdifan Jan 21, 2022
918776f
maint
dengdifan Jan 21, 2022
8570951
allow to change decoder search space
dengdifan Jan 21, 2022
90edcfb
more resampling strategy, more options for MLP
dengdifan Jan 24, 2022
df7f568
reduced NBEATS
dengdifan Jan 24, 2022
6a22e4d
subsampler for val loader
dengdifan Jan 25, 2022
01ee1f1
rng for dataloader sampler
dengdifan Jan 25, 2022
5dcdc4e
maint
dengdifan Jan 26, 2022
f32a12b
remove generator as it cannot be pickled
dengdifan Jan 26, 2022
344e7df
allow lower fidelity to evaluate less test instances
dengdifan Jan 28, 2022
94891f2
fix dummy forecastro isues
dengdifan Jan 30, 2022
8fc51b1
maint
dengdifan Jan 30, 2022
0b34683
add gluonts as requirement
dengdifan Jan 30, 2022
f23840b
more data for val set for larger dataset
dengdifan Feb 1, 2022
dc84904
maint
dengdifan Feb 2, 2022
6466d11
Merge branch 'refactor_development_time_series' of https://github.com…
dengdifan Feb 2, 2022
6981553
maint
dengdifan Feb 3, 2022
3fda94a
fix nbeats decoder
Feb 14, 2022
d95e230
new dataset interface
dengdifan Feb 16, 2022
d185609
Merge branch 'refactor_development_time_series' of https://github.com…
dengdifan Feb 16, 2022
d5459fa
resolve conflict
dengdifan Feb 16, 2022
510cc5a
maint
dengdifan Feb 16, 2022
3806fe2
allow encoder to receive input from different sources
dengdifan Feb 16, 2022
9251bbc
multi blocks hp design
dengdifan Feb 18, 2022
5617db6
maint
dengdifan Feb 20, 2022
d04cb04
correct hp updates
dengdifan Feb 20, 2022
7881bb5
first trial on nested conjunction
dengdifan Feb 21, 2022
d7bff6e
maint
dengdifan Feb 21, 2022
2153bc2
fit for deep AR model (needs to be reverted when the issue in ConfigS…
dengdifan Feb 21, 2022
b2063e7
adjust backbones to fit new structure
dengdifan Feb 23, 2022
59cee13
further API changes
dengdifan Feb 28, 2022
b2b5580
tft temporal fusion decoder
dengdifan Feb 28, 2022
57461b9
construct network
dengdifan Mar 2, 2022
20eb852
cells for networks
dengdifan Mar 2, 2022
f5cede7
forecasting backbones
dengdifan Mar 4, 2022
50c559e
maint
dengdifan Mar 4, 2022
2dd0b11
maint
dengdifan Mar 6, 2022
0f0dbf0
move tft layer to backbone
dengdifan Mar 7, 2022
9e68629
maint
dengdifan Mar 7, 2022
ed99ba1
quantile loss
dengdifan Mar 7, 2022
45535ba
maint
dengdifan Mar 8, 2022
9fac9fe
maint
dengdifan Mar 8, 2022
75570c2
maint
dengdifan Mar 8, 2022
2f954cd
maint
dengdifan Mar 8, 2022
31f8ddc
maint
dengdifan Mar 8, 2022
7f4911e
maint
dengdifan Mar 8, 2022
2e31fdb
forecasting init configs
dengdifan Mar 8, 2022
125921c
add forbidden
dengdifan Mar 9, 2022
8d704d1
maint
dengdifan Mar 10, 2022
e646672
maint
dengdifan Mar 10, 2022
a2ad3fe
maint
dengdifan Mar 10, 2022
200691c
remove shift data
dengdifan Mar 11, 2022
538f24e
maint
dengdifan Mar 11, 2022
12ccf4b
maint
dengdifan Mar 11, 2022
4d6853d
copy dataset_properties for each refit iteration
dengdifan Mar 11, 2022
34d556a
maint and new init
dengdifan Mar 14, 2022
37501ef
Tft forecating with features (#6)
dengdifan Mar 16, 2022
5746541
fix loss computation in QuantileLoss
dengdifan Mar 16, 2022
b1fbece
fixed scaler computation
dengdifan Mar 18, 2022
683ccf5
maint
dengdifan Mar 19, 2022
95d2ab5
fix dataset
Mar 22, 2022
baaf34f
adjust window_size to seasonality
Mar 22, 2022
897cd74
maint scaling
dengdifan Mar 23, 2022
a09ddbb
fix uncorrect Seq2Seq scaling
dengdifan Mar 23, 2022
c1dda0a
fix sampling for seq2seq
dengdifan Mar 25, 2022
49ee49c
maint
dengdifan Mar 25, 2022
dc97df2
fix scaling in NBEATS
dengdifan Mar 25, 2022
399572c
move time feature computation to dataset
dengdifan Mar 28, 2022
7154308
maint
dengdifan Mar 30, 2022
1ba08fe
fix feature computation
dengdifan Mar 31, 2022
04a69d8
maint
dengdifan Mar 31, 2022
471db34
multi-variant feature validator
dengdifan Apr 13, 2022
16cf754
resolve conflicts
dengdifan Apr 13, 2022
cc77b51
maint
dengdifan Apr 13, 2022
fe6fb1f
validator for multi-variant series
dengdifan Apr 13, 2022
9264f89
feature validator
dengdifan Apr 14, 2022
aa3f7a6
multi-variant datasets
dengdifan Apr 14, 2022
974f8ff
observed targets
dengdifan Apr 14, 2022
37dd821
stucture adjustment
dengdifan Apr 20, 2022
1a6e19d
refactory ts tasks and preprocessing
dengdifan Apr 22, 2022
075c6e6
allow nan in targets
dengdifan Apr 22, 2022
2487117
preprocessing for time series
dengdifan Apr 22, 2022
86e4e3c
maint
dengdifan Apr 25, 2022
2c9944c
forecasting pipeline
dengdifan Apr 25, 2022
7eb5139
maint
dengdifan Apr 26, 2022
22fc0bc
embedding and maint
dengdifan Apr 26, 2022
1759fdf
move targets to the tail of the features
dengdifan Apr 26, 2022
9652c80
maint
dengdifan Apr 26, 2022
1d89636
static features
dengdifan Apr 27, 2022
282d63b
adjsut scaler to static features
dengdifan Apr 27, 2022
fb8b805
remove static features from forward dict
dengdifan Apr 27, 2022
533f12d
test transform
dengdifan Apr 27, 2022
f8be97c
maint
dengdifan Apr 28, 2022
e8c9071
test sets
dengdifan Apr 28, 2022
2779015
adjust dataset to allow future known features
dengdifan Apr 29, 2022
1a1fe68
maint
dengdifan Apr 29, 2022
f4ad355
maint
dengdifan Apr 29, 2022
79ef7a7
flake8
dengdifan Apr 29, 2022
88977e0
synchronise with development
dengdifan May 2, 2022
b269ff8
recover timeseries
dengdifan May 2, 2022
31f9e43
maint
dengdifan May 2, 2022
67ea836
maint
dengdifan May 2, 2022
80b8ac2
limit memory usage tae
dengdifan May 2, 2022
d01e2a7
revert test api
dengdifan May 2, 2022
3be7be9
test for targets
dengdifan May 3, 2022
77dcb7c
not allow sparse forecasting target
dengdifan May 3, 2022
6932199
test for data validator
dengdifan May 4, 2022
ee97108
test for validations
dengdifan May 5, 2022
b7f51f2
test on TimeSeriesSequence
dengdifan May 5, 2022
08bfe18
maint
dengdifan May 5, 2022
478ad68
test for resampling
dengdifan May 6, 2022
1986593
test for dataset 1
dengdifan May 7, 2022
112c876
test for datasets
dengdifan May 8, 2022
235e310
test on tae
dengdifan May 9, 2022
9d8dd0b
maint
dengdifan May 9, 2022
dc4b602
all evaluator to evalaute test sets
dengdifan May 10, 2022
e8cf8cb
tests on losses
dengdifan May 10, 2022
e5b1c47
test for metrics
dengdifan May 10, 2022
3f47489
forecasting preprocessing
dengdifan May 10, 2022
835055d
maint
dengdifan May 11, 2022
ef9e44e
finish test for preprocessing
dengdifan May 11, 2022
21b3958
test for data loader
dengdifan May 12, 2022
101ddbc
tests for dataloader
dengdifan May 13, 2022
7318086
maint
dengdifan May 13, 2022
cf2c982
test for target scaling 1
dengdifan May 13, 2022
8b7ef61
test for target scaer
dengdifan May 15, 2022
1025b93
test for training loss
dengdifan May 15, 2022
6f68633
maint
dengdifan May 16, 2022
570408d
test for network backbone
dengdifan May 16, 2022
7d42007
test for backbone base
dengdifan May 17, 2022
2033075
test for flat encoder
dengdifan May 17, 2022
c6e2239
test for seq encoder
dengdifan May 17, 2022
727e48e
test for seqencoder
dengdifan May 18, 2022
23dde67
maint
dengdifan May 18, 2022
4d9fe30
test for recurrent decoders
dengdifan May 19, 2022
eb5a7ec
test for network
dengdifan May 19, 2022
0ea372e
maint
dengdifan May 19, 2022
1b7ebbe
test for architecture
dengdifan May 20, 2022
f055fd5
test for pipelines
dengdifan May 20, 2022
ccab50e
fixed sampler
dengdifan May 21, 2022
54acaa6
maint sampler
dengdifan May 21, 2022
da6e92d
resolve conflict between embedding and net encoder
dengdifan May 21, 2022
fba012c
fix scaling
dengdifan May 21, 2022
2ed1197
allow transform for test dataloader
dengdifan May 21, 2022
95eb783
maint dataloader
dengdifan May 21, 2022
8035221
fix updates
dengdifan May 22, 2022
f3cb2de
fix dataset
dengdifan May 23, 2022
0af1217
tests on api, initial design on multi-variant
dengdifan May 24, 2022
c717fae
maint
dengdifan May 24, 2022
78d7a51
fix dataloader
dengdifan May 24, 2022
fa5cc75
move test with for loop to unittest.subtest
dengdifan May 24, 2022
2d2e039
flake 8 and update requirement
dengdifan May 24, 2022
a1c7930
mypy
dengdifan May 24, 2022
ba96c37
validator for pd dataframe
dengdifan May 27, 2022
cdcdb5a
allow series idx for api
dengdifan May 27, 2022
43671dd
maint
dengdifan May 30, 2022
806afb3
examples for forecasting
dengdifan May 30, 2022
bc80bf1
fix mypy
dengdifan May 30, 2022
c584a58
properly memory limitation for forecasting example
dengdifan May 30, 2022
0e37178
fix pre-commit
dengdifan May 30, 2022
1cf31b2
maint dataloader
dengdifan May 31, 2022
a8fa53c
remove unused auto-regressive arguments
dengdifan May 31, 2022
a8bd54d
fix pre-commit
dengdifan May 31, 2022
609ccf1
maint
dengdifan May 31, 2022
168b7cf
maint mypy
dengdifan May 31, 2022
88c2354
mypy!!!
dengdifan May 31, 2022
374cc1d
pre-commit
dengdifan May 31, 2022
4898ca5
mypyyyyyyyyyyyyyyyyyyyyyyyy
dengdifan May 31, 2022
694eebb
maint
dengdifan Jun 13, 2022
abd3900
move forcasting requirements to extras_require
dengdifan Jun 13, 2022
776aa84
bring eval_test to tae
dengdifan Jun 14, 2022
f70e2b3
make rh2epm consistent with SMAC4HPO
dengdifan Jun 14, 2022
50f6f18
remove smac4ac from smbo
dengdifan Jun 14, 2022
2663ad9
revert changes in network
dengdifan Jun 14, 2022
58eeb0c
revert changes in trainer
dengdifan Jun 14, 2022
b86908f
revert format changes
dengdifan Jun 14, 2022
68d8a25
move constant_forecasting to constatn
dengdifan Jun 14, 2022
dac5cdd
additional annotate for base pipeline
dengdifan Jun 14, 2022
7f2d394
move forecasting check to tae
dengdifan Jun 14, 2022
e43d70a
maint time series refit dataset
dengdifan Jun 14, 2022
dc48b9d
fix test
dengdifan Jun 14, 2022
1e7253a
workflow for extra requirements
dengdifan Jun 14, 2022
83e2469
docs for time series dataset
dengdifan Jun 14, 2022
1671992
fix pre-commit
dengdifan Jun 14, 2022
97d3835
docs for dataset
dengdifan Jun 14, 2022
889c5e9
maint docstring
dengdifan Jun 14, 2022
f68dc18
merge target scaler to one file
dengdifan Jun 14, 2022
dc4f510
fix forecasting init cfgs
dengdifan Jun 14, 2022
951ef4e
remove redudant pipeline configs
dengdifan Jun 14, 2022
10f0c83
maint
dengdifan Jun 14, 2022
8574c6f
SMAC4HPO instead of SMAC4AC in smbo (will be reverted further if stud…
dengdifan Jun 15, 2022
86e39bc
fixed docstrign for RNN and Transformer Decoder
dengdifan Jun 15, 2022
21fbcb2
uniformed docstrings for smbo and base task
dengdifan Jun 15, 2022
ee66c25
correct encoder to decoder in decoder.init
dengdifan Jun 15, 2022
877a124
fix doc strings
dengdifan Jun 15, 2022
1d3a74e
add license and docstrings for NBEATS heads
dengdifan Jun 16, 2022
2516859
allow memory limit to be None
dengdifan Jun 16, 2022
fe5e587
relax test load for forecasting
dengdifan Jun 16, 2022
2c6f66f
fix docs
dengdifan Jun 16, 2022
bb7f5c5
fix pre-commit
dengdifan Jun 16, 2022
9d728b5
make test compatible with py37
dengdifan Jun 17, 2022
a331093
maint docstring
dengdifan Jun 17, 2022
8a5a91b
split forecasting_eval_train_function from eval_train_function
dengdifan Jun 17, 2022
acddd22
fix namespace for test_api from train_evaluator to tae
dengdifan Jun 17, 2022
b18ce92
maint test api for forecasting
dengdifan Jun 17, 2022
0700e61
decrease number of ensemble size of test_time_series_forecasting to r…
dengdifan Jun 17, 2022
e4328ee
flatten all the prediction for forecasting pipelines
dengdifan Jun 17, 2022
b6baef1
pre-commit fix
dengdifan Jun 17, 2022
c1de20f
Merge remote-tracking branch 'upstream/development' into time_series_…
dengdifan Jun 20, 2022
0771c8e
fix docstrings and typing
dengdifan Jun 20, 2022
d066fda
maint time series dataset docstrings
dengdifan Jun 22, 2022
f701df3
maint warning message in time_series_forecasting_train_evaluator
dengdifan Jun 22, 2022
5e970f6
fix lines that are overlength
dengdifan Jun 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test for pipelines
  • Loading branch information
dengdifan committed May 20, 2022
commit f055fd5d38d9452850cf4c32cafdd82cb0ed48eb
16 changes: 14 additions & 2 deletions autoPyTorch/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,24 @@ def _check_search_space_updates(self, include: Optional[Dict[str, Any]],
# check if component is not present in include
if include is not None and update.node_name in include.keys():
if split_hyperparameter[0] not in include[update.node_name]:
raise ValueError("Not found {} in include".format(split_hyperparameter[0]))
hp_in_component = False
for include_component in include[update.node_name]:
if include_component.startswith(split_hyperparameter[0]):
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved
hp_in_component = True
break
if not hp_in_component:
raise ValueError("Not found {} in include".format(split_hyperparameter[0]))

# check if component is present in exclude
if exclude is not None and update.node_name in exclude.keys():
if split_hyperparameter[0] in exclude[update.node_name]:
raise ValueError("Found {} in exclude".format(split_hyperparameter[0]))
hp_in_component = False
for exclude_component in exclude[update.node_name]:
if exclude_component.startswith(split_hyperparameter[0]):
hp_in_component = True
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved
break
if not hp_in_component:
raise ValueError("Found {} in exclude".format(split_hyperparameter[0]))

components = node.get_components()
# if hyperparameter is __choice__, check if
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def fitted_encoder(self):
return ['NBEATSEncoder']

def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
X.update({'backcast_loss_ratio': self.config['backcast_loss_ratio']})
X.update({'backcast_loss_ration': self.config['backcast_loss_ration']})
return super().transform(X)

@staticmethod
Expand Down Expand Up @@ -275,8 +275,8 @@ def get_hyperparameter_search_space(
value_range=(0, 0.8),
default_value=0.1,
),
backcast_loss_ratio: HyperparameterSearchSpace = HyperparameterSearchSpace(
hyperparameter="backcast_loss_ratio",
backcast_loss_ration: HyperparameterSearchSpace = HyperparameterSearchSpace(
hyperparameter="backcast_loss_ration",
value_range=(0., 1.),
default_value=1.,
)
Expand Down Expand Up @@ -315,7 +315,7 @@ def get_hyperparameter_search_space(
use_dropout: if dropout is applied
normalization: if normalization is applied
dropout: dropout value, if use_dropout is set as True
backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss.
backcast_loss_ration: weight of backcast in comparison to forecast when calculating the loss.
A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and
forecast lengths). Defaults to 0.0, i.e. no weight.
Returns:
Expand All @@ -329,7 +329,7 @@ def get_hyperparameter_search_space(
# General Hyperparameters
add_hyperparameter(cs, activation, CategoricalHyperparameter)
add_hyperparameter(cs, normalization, CategoricalHyperparameter)
add_hyperparameter(cs, backcast_loss_ratio, UniformFloatHyperparameter)
add_hyperparameter(cs, backcast_loss_ration, UniformFloatHyperparameter)

cs.add_hyperparameter(n_beats_type)
# N-BEATS-G
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def evaluate(self, test_loader: torch.utils.data.DataLoader, epoch: int,
Dict[str, float]: scores for each desired metric
"""
if not isinstance(self.model, (ForecastingDeepARNet, ForecastingSeq2SeqNet)):
# To save time, we simply make one step prediction for DeepAR and Seq2Seq
# To save time, we simply make one-step prediction for DeepAR and Seq2Seq
self.model.eval()
if isinstance(self.model, ForecastingDeepARNet):
self.model.only_generate_future_dist = True
Expand Down
1 change: 1 addition & 0 deletions autoPyTorch/pipeline/create_searchspace_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def find_active_choices(
) -> List[str]:
if not hasattr(node, "get_available_components"):
raise ValueError()

available_components = node.get_available_components(dataset_properties,
include=include,
exclude=exclude)
Expand Down
113 changes: 64 additions & 49 deletions autoPyTorch/pipeline/time_series_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,17 @@ def _get_hyperparameter_search_space(self,
forbidden_regression_losses_all.append(forbidden_hp_dist)
"""

# NBEATS only works with NoEmbedding
if 'network_backbone:flat_encoder:__choice__' in cs:
hp_flat_encoder = cs.get_hyperparameter('network_backbone:flat_encoder:__choice__')
if 'NBEATSEncoder' in hp_flat_encoder.choices:
cs.add_forbidden_clause(ForbiddenAndConjunction(
ForbiddenEqualsClause(hp_flat_encoder, 'NBEATSEncoder'),
cs.get_hyperparameter(
'network_embedding:__choice__'), 'LearnedEntityEmbedding')
)


# dist_cls and auto_regressive are only activate if the network outputs distribution
if 'loss' in self.named_steps.keys() and 'network_backbone' in self.named_steps.keys():
hp_loss = cs.get_hyperparameter('loss:__choice__')
Expand All @@ -225,44 +236,47 @@ def _get_hyperparameter_search_space(self,
forbidden_hp_dist = ForbiddenAndConjunction(forbidden_hp_dist, forbidden_hp_regression_loss)
forbidden_losses_all.append(forbidden_hp_dist)

decoder_auto_regressive = cs.get_hyperparameter("network_backbone:seq_encoder:decoder_auto_regressive")
forecast_strategy = cs.get_hyperparameter("loss:DistributionLoss:forecast_strategy")
use_tf = cs.get_hyperparameter("network_backbone:seq_encoder:use_temporal_fusion")

if True in decoder_auto_regressive.choices and\
'sample' in forecast_strategy.choices and True in use_tf.choices:
cs.add_forbidden_clause(
ForbiddenAndConjunction(
ForbiddenEqualsClause(decoder_auto_regressive, True),
ForbiddenEqualsClause(forecast_strategy, 'sample'),
ForbiddenEqualsClause(use_tf, True)
if "network_backbone:seq_encoder:decoder_auto_regressive" in cs:
decoder_auto_regressive = cs.get_hyperparameter("network_backbone:seq_encoder:decoder_auto_regressive")
forecast_strategy = cs.get_hyperparameter("loss:DistributionLoss:forecast_strategy")
use_tf = cs.get_hyperparameter("network_backbone:seq_encoder:use_temporal_fusion")

if True in decoder_auto_regressive.choices and\
'sample' in forecast_strategy.choices and True in use_tf.choices:
cs.add_forbidden_clause(
ForbiddenAndConjunction(
ForbiddenEqualsClause(decoder_auto_regressive, True),
ForbiddenEqualsClause(forecast_strategy, 'sample'),
ForbiddenEqualsClause(use_tf, True)
)
)
)

network_flat_encoder_hp = cs.get_hyperparameter('network_backbone:flat_encoder:__choice__')

if 'MLPEncoder' in network_flat_encoder_hp.choices:
forbidden = ['MLPEncoder']
forbidden_deepAREncoder = [forbid for forbid in forbidden if forbid in network_flat_encoder_hp.choices]
for hp_ar in hp_deepAR:
if True in hp_ar.choices:
forbidden_hp_ar = ForbiddenEqualsClause(hp_ar, ar_forbidden)
forbidden_hp_mlpencoder = ForbiddenInClause(network_flat_encoder_hp, forbidden_deepAREncoder)
forbidden_hp_ar_mlp = ForbiddenAndConjunction(forbidden_hp_ar, forbidden_hp_mlpencoder)
forbidden_losses_all.append(forbidden_hp_ar_mlp)

forecast_strategy = cs.get_hyperparameter('loss:DistributionLoss:forecast_strategy')
if 'mean' in forecast_strategy.choices:
for hp_ar in hp_deepAR:
if True in hp_ar.choices:

forbidden_hp_ar = ForbiddenEqualsClause(hp_ar, ar_forbidden)
forbidden_hp_forecast_strategy = ForbiddenEqualsClause(forecast_strategy, 'mean')
forbidden_hp_ar_forecast_strategy = ForbiddenAndConjunction(forbidden_hp_ar,
forbidden_hp_forecast_strategy)
forbidden_losses_all.append(forbidden_hp_ar_forecast_strategy)

cs.add_forbidden_clauses(forbidden_losses_all)
if 'network_backbone:flat_encoder:__choice__' in cs:
network_flat_encoder_hp = cs.get_hyperparameter('network_backbone:flat_encoder:__choice__')

if 'MLPEncoder' in network_flat_encoder_hp.choices:
forbidden = ['MLPEncoder']
forbidden_deepAREncoder = [forbid for forbid in forbidden if forbid in network_flat_encoder_hp.choices]
for hp_ar in hp_deepAR:
if True in hp_ar.choices:
forbidden_hp_ar = ForbiddenEqualsClause(hp_ar, ar_forbidden)
forbidden_hp_mlpencoder = ForbiddenInClause(network_flat_encoder_hp, forbidden_deepAREncoder)
forbidden_hp_ar_mlp = ForbiddenAndConjunction(forbidden_hp_ar, forbidden_hp_mlpencoder)
forbidden_losses_all.append(forbidden_hp_ar_mlp)

if 'loss:DistributionLoss:forecast_strategy' in cs:
forecast_strategy = cs.get_hyperparameter('loss:DistributionLoss:forecast_strategy')
if 'mean' in forecast_strategy.choices:
for hp_ar in hp_deepAR:
if True in hp_ar.choices:

forbidden_hp_ar = ForbiddenEqualsClause(hp_ar, ar_forbidden)
forbidden_hp_forecast_strategy = ForbiddenEqualsClause(forecast_strategy, 'mean')
forbidden_hp_ar_forecast_strategy = ForbiddenAndConjunction(forbidden_hp_ar,
forbidden_hp_forecast_strategy)
forbidden_losses_all.append(forbidden_hp_ar_forecast_strategy)
if forbidden_losses_all:
cs.add_forbidden_clauses(forbidden_losses_all)

# NBEATS
network_encoder_hp = cs.get_hyperparameter("network_backbone:__choice__")
Expand All @@ -275,21 +289,22 @@ def _get_hyperparameter_search_space(self,
forbidden_loss_non_regression = ForbiddenInClause(hp_loss, loss_non_regression)
forbidden_backcast = ForbiddenEqualsClause(data_loader_backcast, True)

hp_flat_encoder = cs.get_hyperparameter("network_backbone:flat_encoder:__choice__")

# Ensure that NBEATS encoder only works with NBEATS decoder
if 'NBEATSEncoder' in hp_flat_encoder.choices:
forbidden_NBEATS.append(ForbiddenAndConjunction(
ForbiddenEqualsClause(hp_flat_encoder, 'NBEATSEncoder'),
forbidden_loss_non_regression)
)
transform_time_features = "data_loader:transform_time_features"
if transform_time_features in cs:
hp_ttf = cs.get_hyperparameter(transform_time_features)
if 'network_backbone:flat_encoder:__choice__' in cs:
hp_flat_encoder = cs.get_hyperparameter("network_backbone:flat_encoder:__choice__")

# Ensure that NBEATS encoder only works with NBEATS decoder
if 'NBEATSEncoder' in hp_flat_encoder.choices:
forbidden_NBEATS.append(ForbiddenAndConjunction(
ForbiddenEqualsClause(hp_flat_encoder, 'NBEATSEncoder'),
ForbiddenEqualsClause(hp_ttf, True))
forbidden_loss_non_regression)
)
transform_time_features = "data_loader:transform_time_features"
if transform_time_features in cs:
hp_ttf = cs.get_hyperparameter(transform_time_features)
forbidden_NBEATS.append(ForbiddenAndConjunction(
ForbiddenEqualsClause(hp_flat_encoder, 'NBEATSEncoder'),
ForbiddenEqualsClause(hp_ttf, True))
)

forbidden_NBEATS.append(ForbiddenAndConjunction(
forbidden_backcast,
Expand Down Expand Up @@ -320,7 +335,7 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]]) -> L
default_dataset_properties.update(dataset_properties)

if not default_dataset_properties.get("uni_variant", False):
steps.extend([("imputer", TimeSeriesFeatureImputer(random_state=self.random_state)),
steps.extend([("impute", TimeSeriesFeatureImputer(random_state=self.random_state)),
("scaler", BaseScaler(random_state=self.random_state)),
('encoding', TimeSeriesEncoderChoice(default_dataset_properties,
random_state=self.random_state)),
Expand Down
41 changes: 23 additions & 18 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def generate_forecasting_features(feature_type, length):
return features, targets, input_validator.fit(features, targets, start_times=start_times)


def get_forecasting_datamangaer(X, y, validator, with_y_test=True, forecast_horizon=5, freq='1D'):
def get_forecasting_datamangaer(X, y, validator, with_y_test=True, forecast_horizon=3, freq='1D'):
if X is not None:
X_test = []
for x in X:
Expand Down Expand Up @@ -761,7 +761,7 @@ def get_forecasting_datamangaer(X, y, validator, with_y_test=True, forecast_hori
return datamanager


def get_forecasting_fit_dictionary(datamanager, backend, budget_type='epochs'):
def get_forecasting_fit_dictionary(datamanager, backend, forecasting_budgets='epochs'):
info = datamanager.get_required_dataset_info()

dataset_properties = datamanager.get_dataset_properties(get_dataset_requirements(info))
Expand All @@ -775,29 +775,29 @@ def get_forecasting_fit_dictionary(datamanager, backend, budget_type='epochs'):
'working_dir': './tmp/example_ensemble_1', # Hopefully generated by backend
'device': 'cpu',
'torch_num_threads': 1,
'early_stopping': 10,
'early_stopping': 1,
'use_tensorboard_logger': False,
'use_pynisher': False,
'metrics_during_training': False,
'seed': 1,
'budget_type': 'epochs',
'epochs': 5,
'epochs': 1,
'split_id': 0,
'backend': backend,
'logger_port': logging.handlers.DEFAULT_TCP_LOGGING_PORT,
}
if budget_type == 'epochs':
fit_dictionary.update({'budget_type': 'epochs',
'epochs': 5})
elif budget_type == 'resolution':
fit_dictionary.update({'budget_type': 'resolution',
'sample_interval': 10})
elif budget_type == 'num_sample_per_seq':
fit_dictionary.update({'budget_type': 'num_samples',
'fraction_samples_per_seq': 0.1})
elif budget_type == 'num_seq':
fit_dictionary.update({'budget_type': 'num_samples',
'fraction_seq': 0.1})
if forecasting_budgets == 'epochs':
fit_dictionary.update({'forecasting_budgets': 'epochs',
'epochs': 1})
elif forecasting_budgets == 'resolution':
fit_dictionary.update({'forecasting_budgets': 'resolution',
'sample_interval': 2})
elif forecasting_budgets == 'num_sample_per_seq':
fit_dictionary.update({'forecasting_budgets': 'num_sample_per_seq',
'fraction_samples_per_seq': 0.5})
elif forecasting_budgets == 'num_seq':
fit_dictionary.update({'forecasting_budgets': 'num_seq',
'fraction_seq': 0.5})
else:
raise NotImplementedError
backend.save_datamanager(datamanager)
Expand Down Expand Up @@ -863,11 +863,16 @@ def get_forecasting_datamanager(request):
return datamanager


@pytest.fixture(params=['epochs'])
def forecasting_budgets(request):
return request.param


@pytest.fixture
def get_fit_dictionary_forecasting(request, backend):
def fit_dictionary_forecasting(request, forecasting_budgets, backend):
X, y, validator = get_forecasting_data(request.param)
datamanager = get_forecasting_datamangaer(X, y, validator)
return get_forecasting_fit_dictionary(datamanager, backend)
return get_forecasting_fit_dictionary(datamanager, backend, forecasting_budgets=forecasting_budgets)


# Fixtures for forecasting validators.
Expand Down
Loading