Time series forecasting #434

Merged
merged 357 commits on Jun 28, 2022
Changes from 1 commit
Commits
f193423
new target scaler, allow NoNorm for MLP Encpder
dengdifan Dec 22, 2021
752a58f
allow sampling full sequences
dengdifan Dec 22, 2021
2ab286b
integrate SeqBuilder to SequenceCollector
dengdifan Dec 22, 2021
d90d630
restore SequenceBuilder to reduce memory usage
dengdifan Dec 22, 2021
adcf8a0
move scaler to network
dengdifan Dec 22, 2021
c3da078
lag sequence
dengdifan Dec 22, 2021
5d67973
merge encoder and decoder as a single pipeline
dengdifan Dec 30, 2021
45d2078
faster lag_seq builder
dengdifan Jan 3, 2022
e4c5358
maint
dengdifan Jan 4, 2022
a0a01b3
new init, faster DeepAR inference in trainer
dengdifan Jan 5, 2022
27e0eb0
more losses types
dengdifan Jan 6, 2022
3710cb2
maint
dengdifan Jan 6, 2022
70d82d9
new Transformer models, allow RNN to do deepAR inference
dengdifan Jan 9, 2022
aa0221a
maint
dengdifan Jan 12, 2022
0a8b5f2
maint
dengdifan Jan 12, 2022
7a7b68d
maint
dengdifan Jan 13, 2022
eeeda02
maint
dengdifan Jan 13, 2022
738f18d
reduced search space for Transformer
dengdifan Jan 13, 2022
353b5c5
reduced init design
dengdifan Jan 13, 2022
e8db57b
maint
dengdifan Jan 13, 2022
8dec08c
maint
dengdifan Jan 14, 2022
7008cec
maint
dengdifan Jan 14, 2022
c7d401e
maint
dengdifan Jan 14, 2022
df15a2b
faster forecasting
dengdifan Jan 16, 2022
ca6a47d
maint
dengdifan Jan 16, 2022
5c9b10c
allow singel fidelity
dengdifan Jan 16, 2022
64580f7
maint
dengdifan Jan 17, 2022
df69c79
fix budget num_seq
dengdifan Jan 17, 2022
91acd5b
faster sampler and lagger
dengdifan Jan 17, 2022
fa96b27
maint
dengdifan Jan 17, 2022
995684a
maint
dengdifan Jan 17, 2022
d3a0e31
maint deepAR
dengdifan Jan 18, 2022
d8b3892
maint
dengdifan Jan 18, 2022
5e9fbae
maint
dengdifan Jan 19, 2022
3eb197e
cross validation
dengdifan Jan 19, 2022
6f4fdf1
allow holdout for smaller datasets
dengdifan Jan 19, 2022
8de8e50
smac4ac to smac4hpo
dengdifan Jan 19, 2022
54fbe59
maint
dengdifan Jan 21, 2022
918776f
maint
dengdifan Jan 21, 2022
8570951
allow to change decoder search space
dengdifan Jan 21, 2022
90edcfb
more resampling strategy, more options for MLP
dengdifan Jan 24, 2022
df7f568
reduced NBEATS
dengdifan Jan 24, 2022
6a22e4d
subsampler for val loader
dengdifan Jan 25, 2022
01ee1f1
rng for dataloader sampler
dengdifan Jan 25, 2022
5dcdc4e
maint
dengdifan Jan 26, 2022
f32a12b
remove generator as it cannot be pickled
dengdifan Jan 26, 2022
344e7df
allow lower fidelity to evaluate less test instances
dengdifan Jan 28, 2022
94891f2
fix dummy forecastro isues
dengdifan Jan 30, 2022
8fc51b1
maint
dengdifan Jan 30, 2022
0b34683
add gluonts as requirement
dengdifan Jan 30, 2022
f23840b
more data for val set for larger dataset
dengdifan Feb 1, 2022
dc84904
maint
dengdifan Feb 2, 2022
6466d11
Merge branch 'refactor_development_time_series' of https://github.com…
dengdifan Feb 2, 2022
6981553
maint
dengdifan Feb 3, 2022
3fda94a
fix nbeats decoder
Feb 14, 2022
d95e230
new dataset interface
dengdifan Feb 16, 2022
d185609
Merge branch 'refactor_development_time_series' of https://github.com…
dengdifan Feb 16, 2022
d5459fa
resolve conflict
dengdifan Feb 16, 2022
510cc5a
maint
dengdifan Feb 16, 2022
3806fe2
allow encoder to receive input from different sources
dengdifan Feb 16, 2022
9251bbc
multi blocks hp design
dengdifan Feb 18, 2022
5617db6
maint
dengdifan Feb 20, 2022
d04cb04
correct hp updates
dengdifan Feb 20, 2022
7881bb5
first trial on nested conjunction
dengdifan Feb 21, 2022
d7bff6e
maint
dengdifan Feb 21, 2022
2153bc2
fit for deep AR model (needs to be reverted when the issue in ConfigS…
dengdifan Feb 21, 2022
b2063e7
adjust backbones to fit new structure
dengdifan Feb 23, 2022
59cee13
further API changes
dengdifan Feb 28, 2022
b2b5580
tft temporal fusion decoder
dengdifan Feb 28, 2022
57461b9
construct network
dengdifan Mar 2, 2022
20eb852
cells for networks
dengdifan Mar 2, 2022
f5cede7
forecasting backbones
dengdifan Mar 4, 2022
50c559e
maint
dengdifan Mar 4, 2022
2dd0b11
maint
dengdifan Mar 6, 2022
0f0dbf0
move tft layer to backbone
dengdifan Mar 7, 2022
9e68629
maint
dengdifan Mar 7, 2022
ed99ba1
quantile loss
dengdifan Mar 7, 2022
45535ba
maint
dengdifan Mar 8, 2022
9fac9fe
maint
dengdifan Mar 8, 2022
75570c2
maint
dengdifan Mar 8, 2022
2f954cd
maint
dengdifan Mar 8, 2022
31f8ddc
maint
dengdifan Mar 8, 2022
7f4911e
maint
dengdifan Mar 8, 2022
2e31fdb
forecasting init configs
dengdifan Mar 8, 2022
125921c
add forbidden
dengdifan Mar 9, 2022
8d704d1
maint
dengdifan Mar 10, 2022
e646672
maint
dengdifan Mar 10, 2022
a2ad3fe
maint
dengdifan Mar 10, 2022
200691c
remove shift data
dengdifan Mar 11, 2022
538f24e
maint
dengdifan Mar 11, 2022
12ccf4b
maint
dengdifan Mar 11, 2022
4d6853d
copy dataset_properties for each refit iteration
dengdifan Mar 11, 2022
34d556a
maint and new init
dengdifan Mar 14, 2022
37501ef
Tft forecating with features (#6)
dengdifan Mar 16, 2022
5746541
fix loss computation in QuantileLoss
dengdifan Mar 16, 2022
b1fbece
fixed scaler computation
dengdifan Mar 18, 2022
683ccf5
maint
dengdifan Mar 19, 2022
95d2ab5
fix dataset
Mar 22, 2022
baaf34f
adjust window_size to seasonality
Mar 22, 2022
897cd74
maint scaling
dengdifan Mar 23, 2022
a09ddbb
fix uncorrect Seq2Seq scaling
dengdifan Mar 23, 2022
c1dda0a
fix sampling for seq2seq
dengdifan Mar 25, 2022
49ee49c
maint
dengdifan Mar 25, 2022
dc97df2
fix scaling in NBEATS
dengdifan Mar 25, 2022
399572c
move time feature computation to dataset
dengdifan Mar 28, 2022
7154308
maint
dengdifan Mar 30, 2022
1ba08fe
fix feature computation
dengdifan Mar 31, 2022
04a69d8
maint
dengdifan Mar 31, 2022
471db34
multi-variant feature validator
dengdifan Apr 13, 2022
16cf754
resolve conflicts
dengdifan Apr 13, 2022
cc77b51
maint
dengdifan Apr 13, 2022
fe6fb1f
validator for multi-variant series
dengdifan Apr 13, 2022
9264f89
feature validator
dengdifan Apr 14, 2022
aa3f7a6
multi-variant datasets
dengdifan Apr 14, 2022
974f8ff
observed targets
dengdifan Apr 14, 2022
37dd821
stucture adjustment
dengdifan Apr 20, 2022
1a6e19d
refactory ts tasks and preprocessing
dengdifan Apr 22, 2022
075c6e6
allow nan in targets
dengdifan Apr 22, 2022
2487117
preprocessing for time series
dengdifan Apr 22, 2022
86e4e3c
maint
dengdifan Apr 25, 2022
2c9944c
forecasting pipeline
dengdifan Apr 25, 2022
7eb5139
maint
dengdifan Apr 26, 2022
22fc0bc
embedding and maint
dengdifan Apr 26, 2022
1759fdf
move targets to the tail of the features
dengdifan Apr 26, 2022
9652c80
maint
dengdifan Apr 26, 2022
1d89636
static features
dengdifan Apr 27, 2022
282d63b
adjsut scaler to static features
dengdifan Apr 27, 2022
fb8b805
remove static features from forward dict
dengdifan Apr 27, 2022
533f12d
test transform
dengdifan Apr 27, 2022
f8be97c
maint
dengdifan Apr 28, 2022
e8c9071
test sets
dengdifan Apr 28, 2022
2779015
adjust dataset to allow future known features
dengdifan Apr 29, 2022
1a1fe68
maint
dengdifan Apr 29, 2022
f4ad355
maint
dengdifan Apr 29, 2022
79ef7a7
flake8
dengdifan Apr 29, 2022
88977e0
synchronise with development
dengdifan May 2, 2022
b269ff8
recover timeseries
dengdifan May 2, 2022
31f9e43
maint
dengdifan May 2, 2022
67ea836
maint
dengdifan May 2, 2022
80b8ac2
limit memory usage tae
dengdifan May 2, 2022
d01e2a7
revert test api
dengdifan May 2, 2022
3be7be9
test for targets
dengdifan May 3, 2022
77dcb7c
not allow sparse forecasting target
dengdifan May 3, 2022
6932199
test for data validator
dengdifan May 4, 2022
ee97108
test for validations
dengdifan May 5, 2022
b7f51f2
test on TimeSeriesSequence
dengdifan May 5, 2022
08bfe18
maint
dengdifan May 5, 2022
478ad68
test for resampling
dengdifan May 6, 2022
1986593
test for dataset 1
dengdifan May 7, 2022
112c876
test for datasets
dengdifan May 8, 2022
235e310
test on tae
dengdifan May 9, 2022
9d8dd0b
maint
dengdifan May 9, 2022
dc4b602
all evaluator to evalaute test sets
dengdifan May 10, 2022
e8cf8cb
tests on losses
dengdifan May 10, 2022
e5b1c47
test for metrics
dengdifan May 10, 2022
3f47489
forecasting preprocessing
dengdifan May 10, 2022
835055d
maint
dengdifan May 11, 2022
ef9e44e
finish test for preprocessing
dengdifan May 11, 2022
21b3958
test for data loader
dengdifan May 12, 2022
101ddbc
tests for dataloader
dengdifan May 13, 2022
7318086
maint
dengdifan May 13, 2022
cf2c982
test for target scaling 1
dengdifan May 13, 2022
8b7ef61
test for target scaer
dengdifan May 15, 2022
1025b93
test for training loss
dengdifan May 15, 2022
6f68633
maint
dengdifan May 16, 2022
570408d
test for network backbone
dengdifan May 16, 2022
7d42007
test for backbone base
dengdifan May 17, 2022
2033075
test for flat encoder
dengdifan May 17, 2022
c6e2239
test for seq encoder
dengdifan May 17, 2022
727e48e
test for seqencoder
dengdifan May 18, 2022
23dde67
maint
dengdifan May 18, 2022
4d9fe30
test for recurrent decoders
dengdifan May 19, 2022
eb5a7ec
test for network
dengdifan May 19, 2022
0ea372e
maint
dengdifan May 19, 2022
1b7ebbe
test for architecture
dengdifan May 20, 2022
f055fd5
test for pipelines
dengdifan May 20, 2022
ccab50e
fixed sampler
dengdifan May 21, 2022
54acaa6
maint sampler
dengdifan May 21, 2022
da6e92d
resolve conflict between embedding and net encoder
dengdifan May 21, 2022
fba012c
fix scaling
dengdifan May 21, 2022
2ed1197
allow transform for test dataloader
dengdifan May 21, 2022
95eb783
maint dataloader
dengdifan May 21, 2022
8035221
fix updates
dengdifan May 22, 2022
f3cb2de
fix dataset
dengdifan May 23, 2022
0af1217
tests on api, initial design on multi-variant
dengdifan May 24, 2022
c717fae
maint
dengdifan May 24, 2022
78d7a51
fix dataloader
dengdifan May 24, 2022
fa5cc75
move test with for loop to unittest.subtest
dengdifan May 24, 2022
2d2e039
flake 8 and update requirement
dengdifan May 24, 2022
a1c7930
mypy
dengdifan May 24, 2022
ba96c37
validator for pd dataframe
dengdifan May 27, 2022
cdcdb5a
allow series idx for api
dengdifan May 27, 2022
43671dd
maint
dengdifan May 30, 2022
806afb3
examples for forecasting
dengdifan May 30, 2022
bc80bf1
fix mypy
dengdifan May 30, 2022
c584a58
properly memory limitation for forecasting example
dengdifan May 30, 2022
0e37178
fix pre-commit
dengdifan May 30, 2022
1cf31b2
maint dataloader
dengdifan May 31, 2022
a8fa53c
remove unused auto-regressive arguments
dengdifan May 31, 2022
a8bd54d
fix pre-commit
dengdifan May 31, 2022
609ccf1
maint
dengdifan May 31, 2022
168b7cf
maint mypy
dengdifan May 31, 2022
88c2354
mypy!!!
dengdifan May 31, 2022
374cc1d
pre-commit
dengdifan May 31, 2022
4898ca5
mypyyyyyyyyyyyyyyyyyyyyyyyy
dengdifan May 31, 2022
694eebb
maint
dengdifan Jun 13, 2022
abd3900
move forcasting requirements to extras_require
dengdifan Jun 13, 2022
776aa84
bring eval_test to tae
dengdifan Jun 14, 2022
f70e2b3
make rh2epm consistent with SMAC4HPO
dengdifan Jun 14, 2022
50f6f18
remove smac4ac from smbo
dengdifan Jun 14, 2022
2663ad9
revert changes in network
dengdifan Jun 14, 2022
58eeb0c
revert changes in trainer
dengdifan Jun 14, 2022
b86908f
revert format changes
dengdifan Jun 14, 2022
68d8a25
move constant_forecasting to constatn
dengdifan Jun 14, 2022
dac5cdd
additional annotate for base pipeline
dengdifan Jun 14, 2022
7f2d394
move forecasting check to tae
dengdifan Jun 14, 2022
e43d70a
maint time series refit dataset
dengdifan Jun 14, 2022
dc48b9d
fix test
dengdifan Jun 14, 2022
1e7253a
workflow for extra requirements
dengdifan Jun 14, 2022
83e2469
docs for time series dataset
dengdifan Jun 14, 2022
1671992
fix pre-commit
dengdifan Jun 14, 2022
97d3835
docs for dataset
dengdifan Jun 14, 2022
889c5e9
maint docstring
dengdifan Jun 14, 2022
f68dc18
merge target scaler to one file
dengdifan Jun 14, 2022
dc4f510
fix forecasting init cfgs
dengdifan Jun 14, 2022
951ef4e
remove redudant pipeline configs
dengdifan Jun 14, 2022
10f0c83
maint
dengdifan Jun 14, 2022
8574c6f
SMAC4HPO instead of SMAC4AC in smbo (will be reverted further if stud…
dengdifan Jun 15, 2022
86e39bc
fixed docstrign for RNN and Transformer Decoder
dengdifan Jun 15, 2022
21fbcb2
uniformed docstrings for smbo and base task
dengdifan Jun 15, 2022
ee66c25
correct encoder to decoder in decoder.init
dengdifan Jun 15, 2022
877a124
fix doc strings
dengdifan Jun 15, 2022
1d3a74e
add license and docstrings for NBEATS heads
dengdifan Jun 16, 2022
2516859
allow memory limit to be None
dengdifan Jun 16, 2022
fe5e587
relax test load for forecasting
dengdifan Jun 16, 2022
2c6f66f
fix docs
dengdifan Jun 16, 2022
bb7f5c5
fix pre-commit
dengdifan Jun 16, 2022
9d728b5
make test compatible with py37
dengdifan Jun 17, 2022
a331093
maint docstring
dengdifan Jun 17, 2022
8a5a91b
split forecasting_eval_train_function from eval_train_function
dengdifan Jun 17, 2022
acddd22
fix namespace for test_api from train_evaluator to tae
dengdifan Jun 17, 2022
b18ce92
maint test api for forecasting
dengdifan Jun 17, 2022
0700e61
decrease number of ensemble size of test_time_series_forecasting to r…
dengdifan Jun 17, 2022
e4328ee
flatten all the prediction for forecasting pipelines
dengdifan Jun 17, 2022
b6baef1
pre-commit fix
dengdifan Jun 17, 2022
c1de20f
Merge remote-tracking branch 'upstream/development' into time_series_…
dengdifan Jun 20, 2022
0771c8e
fix docstrings and typing
dengdifan Jun 20, 2022
d066fda
maint time series dataset docstrings
dengdifan Jun 22, 2022
f701df3
maint warning message in time_series_forecasting_train_evaluator
dengdifan Jun 22, 2022
5e970f6
fix lines that are overlength
dengdifan Jun 22, 2022
fix doc strings
dengdifan committed Jun 15, 2022
commit 877a12481b6beafb235447e9521aedded9660838
90 changes: 80 additions & 10 deletions autoPyTorch/evaluation/time_series_forecasting_train_evaluator.py
@@ -22,6 +22,86 @@


class TimeSeriesForecastingTrainEvaluator(TrainEvaluator):
"""
This class is similar to the TrainEvaluator, except that it is adapted to the specific requirements of time series forecasting tasks.

Attributes:
backend (Backend):
An object to interface with the disk storage. In particular, allows to
access the train and test datasets
queue (Queue):
Each worker available will instantiate an evaluator, and after completion,
it will return the evaluation result via a multiprocessing queue
metric (autoPyTorchMetric):
A scorer object that is able to evaluate how good a pipeline was fit. It
is a wrapper on top of the actual score method (a wrapper on top of scikit-learn
accuracy, for example) that formats the predictions accordingly.
budget (float):
The number of epochs or the amount of time a configuration is allowed to run.
budget_type (str):
The budget type, which can be epochs or time
pipeline_config (Optional[Dict[str, Any]]):
Defines the content of the pipeline being evaluated. For example, it
contains pipeline specific settings like logging name, or whether or not
to use tensorboard.
configuration (Union[int, str, Configuration]):
Determines the pipeline to be constructed. A dummy estimator is created for
integer configurations, a traditional machine learning pipeline is created
for string-based configurations, and NAS is performed when a configuration
object is passed.
seed (int):
An integer that allows for reproducibility of results
output_y_hat_optimization (bool):
Whether this worker should output the target predictions, so that they are
stored on disk. Fundamentally, the resampling strategy might shuffle the
Y_train targets, so we store the split in order to re-use them for ensemble
selection.
num_run (Optional[int]):
An identifier of the current configuration being fit. This number is unique per
configuration.
include (Optional[Dict[str, Any]]):
An optional dictionary to include components of the pipeline steps.
exclude (Optional[Dict[str, Any]]):
An optional dictionary to exclude components of the pipeline steps.
disable_file_output (Optional[List[Union[str, DisableFileOutputParameters]]]):
Used as a list to pass more fine-grained
information on what to save. Must be a member of `DisableFileOutputParameters`.
Allowed elements in the list are:

+ `y_optimization`:
do not save the predictions for the optimization set,
which would later on be used to build an ensemble. Note that SMAC
optimizes a metric evaluated on the optimization set.
+ `pipeline`:
do not save any individual pipeline files
+ `pipelines`:
In case of cross validation, disables saving the joint model of the
pipelines fit on each fold.
+ `y_test`:
do not save the predictions for the test set.
+ `all`:
do not save any of the above.
For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`.
init_params (Optional[Dict[str, Any]]):
Optional argument that is passed to each pipeline step. It is the equivalent of
kwargs for the pipeline steps.
logger_port (Optional[int]):
Logging is performed using a socket-server scheme to be robust against many
parallel entities that want to write to the same file. This integer states the
socket port for the communication channel. If None is provided, a traditional
logger is used.
all_supported_metrics (bool):
Whether all supported metrics should be calculated for every configuration.
search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
An object used to fine tune the hyperparameter search space of the pipeline
max_budget (float):
Maximal budget value available to the optimizer. This is used to compute the size of the proxy
validation sets
min_num_test_instances (Optional[int]):
Minimal number of validation instances to be evaluated; this ensures that there are enough
instances in the validation set

"""
def __init__(self, backend: Backend, queue: Queue,
metric: autoPyTorchMetric,
budget: float,
@@ -41,16 +121,6 @@ def __init__(self, backend: Backend, queue: Queue,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
max_budget: float = 1.0,
min_num_test_instances: Optional[int] = None) -> None:
"""
Attributes:
max_budget (Optional[float]):
maximal budget the optimizer could allocate
min_num_test_instances: Optional[int]
minimal number of validation instances to be evaluated, if the size of the validation set is greater
than this value, then less instances from validation sets will be evaluated. The other predictions
will be filled with dummy predictor

"""
super(TimeSeriesForecastingTrainEvaluator, self).__init__(
backend=backend,
queue=queue,
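To make the budget attributes above concrete, here is a minimal, hypothetical sketch (illustrative names, not the PR's actual code) of how a proxy validation set size can be derived from budget, max_budget and min_num_test_instances:

def proxy_validation_size(num_val_instances: int, budget: float,
                          max_budget: float, min_num_test_instances: int) -> int:
    # small validation sets are always evaluated in full
    if num_val_instances <= min_num_test_instances:
        return num_val_instances
    # lower fidelities evaluate a budget-proportional subset; the remaining
    # predictions are filled in by the dummy predictor
    num_evaluated = int(num_val_instances * budget / max_budget)
    return max(min_num_test_instances, min(num_val_instances, num_evaluated))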
@@ -54,12 +54,12 @@ def get_available_components(
include/exclude directives, as well as the dataset properties

Args:
include (Optional[Dict[str, Any]]):
what hyper-parameter configurations to honor when creating the configuration space
exclude (Optional[Dict[str, Any]]):
what hyper-parameter configurations to remove from the configuration space
dataset_properties (Optional[Dict[str, BaseDatasetPropertiesType]]):
Characteristics of the dataset to guide the pipeline choices of components
include (Optional[Dict[str, Any]]):
what hyper-parameter configurations to honor when creating the configuration space
exclude (Optional[Dict[str, Any]]):
what hyper-parameter configurations to remove from the configuration space
dataset_properties (Optional[Dict[str, BaseDatasetPropertiesType]]):
Characteristics of the dataset to guide the pipeline choices of components

Returns:
Dict[str, autoPyTorchComponent]: A filtered dict of learning
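As a hedged illustration of the include/exclude directives documented above (step and component names are placeholders, not an exhaustive list), the dictionaries typically look like:

include = {'network_embedding': ['NoEmbedding']}   # only these components are considered
exclude = {'network_init': ['KaimingInit']}        # these components are removed

Passing such directives narrows the configuration space before the matching components are filtered.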
@@ -154,6 +154,41 @@ def get_lagged_subsequences_inference(


class AbstractForecastingNet(nn.Module):
"""
This is a basic forecasting network. It is composed only of an embedding net, an encoder and a head (including
an MLP decoder and the final head).

This structure is active when the decoder is an MLP with auto_regressive set to False

Attributes:
network_structure (NetworkStructure):
network structure information
network_embedding (nn.Module):
network embedding
network_encoder (Dict[str, EncoderBlockInfo]):
Encoder network; it can be configured to return a sequence or a 2D matrix
network_decoder (Dict[str, DecoderBlockInfo]):
network decoder
temporal_fusion (Optional[TemporalFusionLayer]):
Temporal Fusion Layer
network_head (nn.Module):
network head, maps the output of decoder to the final output
dataset_properties (Dict):
dataset properties
auto_regressive (bool):
whether the model is an auto-regressive model
output_type (str):
the form of the network output; it can be regression, distribution or quantile
forecast_strategy (str):
only valid if output_type is distribution or quantile; determines how the network transforms
its output into predicted values, either mean or sample
num_samples (int):
only valid if output_type is not regression and forecast_strategy is sample. This indicates the
number of points to sample when doing prediction
aggregation (str):
only valid if output_type is not regression and forecast_strategy is sample. The way that the samples
are aggregated. We could take their mean or median values.
"""
future_target_required = False
dtype = torch.float
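To make output_type / forecast_strategy / num_samples / aggregation concrete, here is a minimal sketch of sample aggregation, assuming the sampled outputs are stacked along a leading sample dimension (names and shapes are illustrative, not the PR's code):

import torch

def aggregate_samples(samples: torch.Tensor, aggregation: str = 'mean') -> torch.Tensor:
    # samples: (num_samples, batch, horizon, output_dim) -> point forecast
    if aggregation == 'mean':
        return samples.mean(dim=0)
    if aggregation == 'median':
        return samples.median(dim=0)[0]
    raise ValueError(f"unknown aggregation: {aggregation}")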

@@ -178,41 +213,6 @@ def __init__(self,
num_samples: int = 50,
aggregation: str = 'mean'
):
"""
This is a basic forecasting network. It is only composed of a embedding net, an encoder and a head (including
MLP decoder and the final head).

This structure is active when the decoder is a MLP with auto_regressive set as false

Args:
network_structure (NetworkStructure):
network structure information
network_embedding (nn.Module):
network embedding
network_encoder (Dict[str, EncoderBlockInfo]):
Encoder network, could be selected to return a sequence or a 2D Matrix
network_decoder (Dict[str, DecoderBlockInfo]):
network decoder
temporal_fusion Optional[TemporalFusionLayer]:
Temporal Fusion Layer
network_head (nn.Module):
network head, maps the output of decoder to the final output
dataset_properties (Dict):
dataset properties
auto_regressive (bool):
if the model is auto-regressive model
output_type (str):
the form that the network outputs. It could be regression, distribution or quantile
forecast_strategy (str):
only valid if output_type is distribution or quantile, how the network transforms
its output to predicted values, could be mean or sample
num_samples (int):
only valid if output_type is not regression and forecast_strategy is sample. This indicates the
number of the points to sample when doing prediction
aggregation (str):
only valid if output_type is not regression and forecast_strategy is sample. The way that the samples
are aggregated. We could take their mean or median values.
"""
super().__init__()
self.network_structure = network_structure
self.embedding = network_embedding
@@ -305,6 +305,23 @@ def rescale_output(self,
loc: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
device: torch.device = torch.device('cpu')) -> ALL_NET_OUTPUT:
"""
rescale the network output to its raw scale

Args:
outputs (ALL_NET_OUTPUT):
network head output
loc (Optional[torch.Tensor]):
scaling location value
scale (Optional[torch.Tensor]):
scaling scale value
device (torch.device):
the device on which the output is stored

Return:
ALL_NET_OUTPUT:
rescaled network output
"""
if isinstance(outputs, List):
return [self.rescale_output(output, loc, scale, device) for output in outputs]
if loc is not None or scale is not None:
@@ -323,17 +340,34 @@ def scale_value(self,
return outputs

def scale_value(self,
outputs: torch.Tensor,
raw_value: torch.Tensor,
loc: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
device: torch.device = torch.device('cpu')) -> torch.Tensor:
"""
scale the raw input value

Args:
raw_value (torch.Tensor):
raw input value to be scaled
loc (Optional[torch.Tensor]):
scaling location value
scale (Optional[torch.Tensor]):
scaling scale value
device (torch.device):
the device on which the value is stored

Return:
torch.Tensor:
scaled input value
"""
if loc is not None or scale is not None:
if loc is None:
outputs = outputs / scale.to(device) # type: ignore[union-attr]
outputs = raw_value / scale.to(device) # type: ignore[union-attr]
elif scale is None:
outputs = outputs - loc.to(device)
outputs = raw_value - loc.to(device)
else:
outputs = (outputs - loc.to(device)) / scale.to(device)
outputs = (raw_value - loc.to(device)) / scale.to(device)
return outputs
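A tiny round-trip sketch of the two methods: scale_value normalises a raw value with loc and scale, while rescale_output applies the inverse transform to the network output (the inverse shown here is inferred from the docstrings; values are illustrative):

import torch

loc, scale = torch.tensor([10.0]), torch.tensor([2.0])
raw = torch.tensor([14.0])
scaled = (raw - loc) / scale     # scale_value: tensor([2.])
restored = scaled * scale + loc  # rescale_output inverts the transform: tensor([14.])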

@abstractmethod
@@ -349,6 +383,17 @@ def forward(self,

@abstractmethod
def pred_from_net_output(self, net_output: ALL_NET_OUTPUT) -> torch.Tensor:
"""
This function transforms the network head output into a torch tensor that contains the point prediction

Args:
net_output (ALL_NET_OUTPUT):
network head output

Return:
torch.Tensor:
point prediction
"""
raise NotImplementedError

@abstractmethod
@@ -364,6 +409,23 @@ def repeat_intermediate_values(self,
intermediate_values: List[Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]],
is_hidden_states: List[bool],
repeats: int) -> List[Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
"""
This function is mainly applied to auto-regressive models, where we sample multiple points to form several
trajectories and need to repeat the intermediate values so that the batch sizes match

Args:
intermediate_values (List[Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]):
a list of intermediate values to be repeated
is_hidden_states (List[bool]):
whether the intermediate value is a hidden state of an RNN-style network; hidden states
need to be treated differently
repeats (int):
number of repeats

Return:
List[Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
repeated values
"""
for i, (is_hx, inter_value) in enumerate(zip(is_hidden_states, intermediate_values)):
if isinstance(inter_value, torch.Tensor):
repeated_value = inter_value.repeat_interleave(repeats=repeats, dim=1 if is_hx else 0)
@@ -375,6 +437,19 @@ def repeat_intermediate_values(self,
return intermediate_values
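For instance, RNN hidden states carry the batch on dim 1 while ordinary tensors carry it on dim 0, which is why the code above switches the repeat dimension (shapes are illustrative):

import torch

hx = torch.zeros(2, 4, 16)                      # hidden state: (num_layers, batch, hidden)
out = torch.zeros(4, 7, 16)                     # ordinary value: (batch, seq, hidden)
hx3 = hx.repeat_interleave(repeats=3, dim=1)    # -> (2, 12, 16)
out3 = out.repeat_interleave(repeats=3, dim=0)  # -> (12, 7, 16)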

def pad_tensor(self, tensor_to_be_padded: torch.Tensor, target_length: int) -> torch.Tensor:
"""
pad tensor to meet the required length

Args:
tensor_to_be_padded (torch.Tensor):
tensor to be padded
target_length (int):
target length

Return:
torch.Tensor:
padded tensor
"""
tensor_shape = tensor_to_be_padded.shape
padding_size = [tensor_shape[0], target_length - tensor_shape[1], tensor_shape[-1]]
tensor_to_be_padded = torch.cat([tensor_to_be_padded.new_zeros(padding_size), tensor_to_be_padded], dim=1)
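A small worked example of the padding above, which prepends zeros along the time axis until target_length is reached (shapes are illustrative):

import torch

t = torch.ones(8, 3, 5)              # (batch, length=3, features)
pad = t.new_zeros(8, 7 - 3, 5)       # zeros for the missing time steps
padded = torch.cat([pad, t], dim=1)  # -> (8, 7, 5); original values sit at the tail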
@@ -1174,6 +1249,9 @@ def forward(self, # type: ignore[override]
past_observed_targets: Optional[torch.BoolTensor] = None,
decoder_observed_values: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor]]:

# Unlike other networks, the NBEATS network is required to predict both past and future targets.
# Therefore, we return two tensors, one for the backcast and one for the forecast
if past_observed_targets is None:
past_observed_targets = torch.ones_like(past_targets, dtype=torch.bool)

Expand All @@ -1194,6 +1272,7 @@ def forward(self, # type: ignore[override]
forecast = torch.zeros(forcast_shape).to(self.device).flatten(1)
backcast, _ = self.encoder(past_targets, [None])
backcast = backcast[0]
# the NBEATS network only has one decoder block (a flat decoder)
for block in self.decoder.decoder['block_1']:
backcast_block, forecast_block = block([None], backcast)

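For context, the loop above follows the standard N-BEATS doubly residual scheme, sketched here with illustrative names (the PR's blocks additionally receive an encoder-input list):

for block in blocks:
    backcast_block, forecast_block = block(backcast)
    backcast = backcast - backcast_block  # remove what this block has explained
    forecast = forecast + forecast_block  # accumulate this block's contribution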