Time series forecasting #434

Merged
merged 357 commits into from
Jun 28, 2022
Commits
f193423
new target scaler, allow NoNorm for MLP Encpder
dengdifan Dec 22, 2021
752a58f
allow sampling full sequences
dengdifan Dec 22, 2021
2ab286b
integrate SeqBuilder to SequenceCollector
dengdifan Dec 22, 2021
d90d630
restore SequenceBuilder to reduce memory usage
dengdifan Dec 22, 2021
adcf8a0
move scaler to network
dengdifan Dec 22, 2021
c3da078
lag sequence
dengdifan Dec 22, 2021
5d67973
merge encoder and decoder as a single pipeline
dengdifan Dec 30, 2021
45d2078
faster lag_seq builder
dengdifan Jan 3, 2022
e4c5358
maint
dengdifan Jan 4, 2022
a0a01b3
new init, faster DeepAR inference in trainer
dengdifan Jan 5, 2022
27e0eb0
more losses types
dengdifan Jan 6, 2022
3710cb2
maint
dengdifan Jan 6, 2022
70d82d9
new Transformer models, allow RNN to do deepAR inference
dengdifan Jan 9, 2022
aa0221a
maint
dengdifan Jan 12, 2022
0a8b5f2
maint
dengdifan Jan 12, 2022
7a7b68d
maint
dengdifan Jan 13, 2022
eeeda02
maint
dengdifan Jan 13, 2022
738f18d
reduced search space for Transformer
dengdifan Jan 13, 2022
353b5c5
reduced init design
dengdifan Jan 13, 2022
e8db57b
maint
dengdifan Jan 13, 2022
8dec08c
maint
dengdifan Jan 14, 2022
7008cec
maint
dengdifan Jan 14, 2022
c7d401e
maint
dengdifan Jan 14, 2022
df15a2b
faster forecasting
dengdifan Jan 16, 2022
ca6a47d
maint
dengdifan Jan 16, 2022
5c9b10c
allow singel fidelity
dengdifan Jan 16, 2022
64580f7
maint
dengdifan Jan 17, 2022
df69c79
fix budget num_seq
dengdifan Jan 17, 2022
91acd5b
faster sampler and lagger
dengdifan Jan 17, 2022
fa96b27
maint
dengdifan Jan 17, 2022
995684a
maint
dengdifan Jan 17, 2022
d3a0e31
maint deepAR
dengdifan Jan 18, 2022
d8b3892
maint
dengdifan Jan 18, 2022
5e9fbae
maint
dengdifan Jan 19, 2022
3eb197e
cross validation
dengdifan Jan 19, 2022
6f4fdf1
allow holdout for smaller datasets
dengdifan Jan 19, 2022
8de8e50
smac4ac to smac4hpo
dengdifan Jan 19, 2022
54fbe59
maint
dengdifan Jan 21, 2022
918776f
maint
dengdifan Jan 21, 2022
8570951
allow to change decoder search space
dengdifan Jan 21, 2022
90edcfb
more resampling strategy, more options for MLP
dengdifan Jan 24, 2022
df7f568
reduced NBEATS
dengdifan Jan 24, 2022
6a22e4d
subsampler for val loader
dengdifan Jan 25, 2022
01ee1f1
rng for dataloader sampler
dengdifan Jan 25, 2022
5dcdc4e
maint
dengdifan Jan 26, 2022
f32a12b
remove generator as it cannot be pickled
dengdifan Jan 26, 2022
344e7df
allow lower fidelity to evaluate less test instances
dengdifan Jan 28, 2022
94891f2
fix dummy forecastro isues
dengdifan Jan 30, 2022
8fc51b1
maint
dengdifan Jan 30, 2022
0b34683
add gluonts as requirement
dengdifan Jan 30, 2022
f23840b
more data for val set for larger dataset
dengdifan Feb 1, 2022
dc84904
maint
dengdifan Feb 2, 2022
6466d11
Merge branch 'refactor_development_time_series' of https://github.com…
dengdifan Feb 2, 2022
6981553
maint
dengdifan Feb 3, 2022
3fda94a
fix nbeats decoder
Feb 14, 2022
d95e230
new dataset interface
dengdifan Feb 16, 2022
d185609
Merge branch 'refactor_development_time_series' of https://github.com…
dengdifan Feb 16, 2022
d5459fa
resolve conflict
dengdifan Feb 16, 2022
510cc5a
maint
dengdifan Feb 16, 2022
3806fe2
allow encoder to receive input from different sources
dengdifan Feb 16, 2022
9251bbc
multi blocks hp design
dengdifan Feb 18, 2022
5617db6
maint
dengdifan Feb 20, 2022
d04cb04
correct hp updates
dengdifan Feb 20, 2022
7881bb5
first trial on nested conjunction
dengdifan Feb 21, 2022
d7bff6e
maint
dengdifan Feb 21, 2022
2153bc2
fit for deep AR model (needs to be reverted when the issue in ConfigS…
dengdifan Feb 21, 2022
b2063e7
adjust backbones to fit new structure
dengdifan Feb 23, 2022
59cee13
further API changes
dengdifan Feb 28, 2022
b2b5580
tft temporal fusion decoder
dengdifan Feb 28, 2022
57461b9
construct network
dengdifan Mar 2, 2022
20eb852
cells for networks
dengdifan Mar 2, 2022
f5cede7
forecasting backbones
dengdifan Mar 4, 2022
50c559e
maint
dengdifan Mar 4, 2022
2dd0b11
maint
dengdifan Mar 6, 2022
0f0dbf0
move tft layer to backbone
dengdifan Mar 7, 2022
9e68629
maint
dengdifan Mar 7, 2022
ed99ba1
quantile loss
dengdifan Mar 7, 2022
45535ba
maint
dengdifan Mar 8, 2022
9fac9fe
maint
dengdifan Mar 8, 2022
75570c2
maint
dengdifan Mar 8, 2022
2f954cd
maint
dengdifan Mar 8, 2022
31f8ddc
maint
dengdifan Mar 8, 2022
7f4911e
maint
dengdifan Mar 8, 2022
2e31fdb
forecasting init configs
dengdifan Mar 8, 2022
125921c
add forbidden
dengdifan Mar 9, 2022
8d704d1
maint
dengdifan Mar 10, 2022
e646672
maint
dengdifan Mar 10, 2022
a2ad3fe
maint
dengdifan Mar 10, 2022
200691c
remove shift data
dengdifan Mar 11, 2022
538f24e
maint
dengdifan Mar 11, 2022
12ccf4b
maint
dengdifan Mar 11, 2022
4d6853d
copy dataset_properties for each refit iteration
dengdifan Mar 11, 2022
34d556a
maint and new init
dengdifan Mar 14, 2022
37501ef
Tft forecating with features (#6)
dengdifan Mar 16, 2022
5746541
fix loss computation in QuantileLoss
dengdifan Mar 16, 2022
b1fbece
fixed scaler computation
dengdifan Mar 18, 2022
683ccf5
maint
dengdifan Mar 19, 2022
95d2ab5
fix dataset
Mar 22, 2022
baaf34f
adjust window_size to seasonality
Mar 22, 2022
897cd74
maint scaling
dengdifan Mar 23, 2022
a09ddbb
fix uncorrect Seq2Seq scaling
dengdifan Mar 23, 2022
c1dda0a
fix sampling for seq2seq
dengdifan Mar 25, 2022
49ee49c
maint
dengdifan Mar 25, 2022
dc97df2
fix scaling in NBEATS
dengdifan Mar 25, 2022
399572c
move time feature computation to dataset
dengdifan Mar 28, 2022
7154308
maint
dengdifan Mar 30, 2022
1ba08fe
fix feature computation
dengdifan Mar 31, 2022
04a69d8
maint
dengdifan Mar 31, 2022
471db34
multi-variant feature validator
dengdifan Apr 13, 2022
16cf754
resolve conflicts
dengdifan Apr 13, 2022
cc77b51
maint
dengdifan Apr 13, 2022
fe6fb1f
validator for multi-variant series
dengdifan Apr 13, 2022
9264f89
feature validator
dengdifan Apr 14, 2022
aa3f7a6
multi-variant datasets
dengdifan Apr 14, 2022
974f8ff
observed targets
dengdifan Apr 14, 2022
37dd821
stucture adjustment
dengdifan Apr 20, 2022
1a6e19d
refactory ts tasks and preprocessing
dengdifan Apr 22, 2022
075c6e6
allow nan in targets
dengdifan Apr 22, 2022
2487117
preprocessing for time series
dengdifan Apr 22, 2022
86e4e3c
maint
dengdifan Apr 25, 2022
2c9944c
forecasting pipeline
dengdifan Apr 25, 2022
7eb5139
maint
dengdifan Apr 26, 2022
22fc0bc
embedding and maint
dengdifan Apr 26, 2022
1759fdf
move targets to the tail of the features
dengdifan Apr 26, 2022
9652c80
maint
dengdifan Apr 26, 2022
1d89636
static features
dengdifan Apr 27, 2022
282d63b
adjsut scaler to static features
dengdifan Apr 27, 2022
fb8b805
remove static features from forward dict
dengdifan Apr 27, 2022
533f12d
test transform
dengdifan Apr 27, 2022
f8be97c
maint
dengdifan Apr 28, 2022
e8c9071
test sets
dengdifan Apr 28, 2022
2779015
adjust dataset to allow future known features
dengdifan Apr 29, 2022
1a1fe68
maint
dengdifan Apr 29, 2022
f4ad355
maint
dengdifan Apr 29, 2022
79ef7a7
flake8
dengdifan Apr 29, 2022
88977e0
synchronise with development
dengdifan May 2, 2022
b269ff8
recover timeseries
dengdifan May 2, 2022
31f9e43
maint
dengdifan May 2, 2022
67ea836
maint
dengdifan May 2, 2022
80b8ac2
limit memory usage tae
dengdifan May 2, 2022
d01e2a7
revert test api
dengdifan May 2, 2022
3be7be9
test for targets
dengdifan May 3, 2022
77dcb7c
not allow sparse forecasting target
dengdifan May 3, 2022
6932199
test for data validator
dengdifan May 4, 2022
ee97108
test for validations
dengdifan May 5, 2022
b7f51f2
test on TimeSeriesSequence
dengdifan May 5, 2022
08bfe18
maint
dengdifan May 5, 2022
478ad68
test for resampling
dengdifan May 6, 2022
1986593
test for dataset 1
dengdifan May 7, 2022
112c876
test for datasets
dengdifan May 8, 2022
235e310
test on tae
dengdifan May 9, 2022
9d8dd0b
maint
dengdifan May 9, 2022
dc4b602
all evaluator to evalaute test sets
dengdifan May 10, 2022
e8cf8cb
tests on losses
dengdifan May 10, 2022
e5b1c47
test for metrics
dengdifan May 10, 2022
3f47489
forecasting preprocessing
dengdifan May 10, 2022
835055d
maint
dengdifan May 11, 2022
ef9e44e
finish test for preprocessing
dengdifan May 11, 2022
21b3958
test for data loader
dengdifan May 12, 2022
101ddbc
tests for dataloader
dengdifan May 13, 2022
7318086
maint
dengdifan May 13, 2022
cf2c982
test for target scaling 1
dengdifan May 13, 2022
8b7ef61
test for target scaer
dengdifan May 15, 2022
1025b93
test for training loss
dengdifan May 15, 2022
6f68633
maint
dengdifan May 16, 2022
570408d
test for network backbone
dengdifan May 16, 2022
7d42007
test for backbone base
dengdifan May 17, 2022
2033075
test for flat encoder
dengdifan May 17, 2022
c6e2239
test for seq encoder
dengdifan May 17, 2022
727e48e
test for seqencoder
dengdifan May 18, 2022
23dde67
maint
dengdifan May 18, 2022
4d9fe30
test for recurrent decoders
dengdifan May 19, 2022
eb5a7ec
test for network
dengdifan May 19, 2022
0ea372e
maint
dengdifan May 19, 2022
1b7ebbe
test for architecture
dengdifan May 20, 2022
f055fd5
test for pipelines
dengdifan May 20, 2022
ccab50e
fixed sampler
dengdifan May 21, 2022
54acaa6
maint sampler
dengdifan May 21, 2022
da6e92d
resolve conflict between embedding and net encoder
dengdifan May 21, 2022
fba012c
fix scaling
dengdifan May 21, 2022
2ed1197
allow transform for test dataloader
dengdifan May 21, 2022
95eb783
maint dataloader
dengdifan May 21, 2022
8035221
fix updates
dengdifan May 22, 2022
f3cb2de
fix dataset
dengdifan May 23, 2022
0af1217
tests on api, initial design on multi-variant
dengdifan May 24, 2022
c717fae
maint
dengdifan May 24, 2022
78d7a51
fix dataloader
dengdifan May 24, 2022
fa5cc75
move test with for loop to unittest.subtest
dengdifan May 24, 2022
2d2e039
flake 8 and update requirement
dengdifan May 24, 2022
a1c7930
mypy
dengdifan May 24, 2022
ba96c37
validator for pd dataframe
dengdifan May 27, 2022
cdcdb5a
allow series idx for api
dengdifan May 27, 2022
43671dd
maint
dengdifan May 30, 2022
806afb3
examples for forecasting
dengdifan May 30, 2022
bc80bf1
fix mypy
dengdifan May 30, 2022
c584a58
properly memory limitation for forecasting example
dengdifan May 30, 2022
0e37178
fix pre-commit
dengdifan May 30, 2022
1cf31b2
maint dataloader
dengdifan May 31, 2022
a8fa53c
remove unused auto-regressive arguments
dengdifan May 31, 2022
a8bd54d
fix pre-commit
dengdifan May 31, 2022
609ccf1
maint
dengdifan May 31, 2022
168b7cf
maint mypy
dengdifan May 31, 2022
88c2354
mypy!!!
dengdifan May 31, 2022
374cc1d
pre-commit
dengdifan May 31, 2022
4898ca5
mypyyyyyyyyyyyyyyyyyyyyyyyy
dengdifan May 31, 2022
694eebb
maint
dengdifan Jun 13, 2022
abd3900
move forcasting requirements to extras_require
dengdifan Jun 13, 2022
776aa84
bring eval_test to tae
dengdifan Jun 14, 2022
f70e2b3
make rh2epm consistent with SMAC4HPO
dengdifan Jun 14, 2022
50f6f18
remove smac4ac from smbo
dengdifan Jun 14, 2022
2663ad9
revert changes in network
dengdifan Jun 14, 2022
58eeb0c
revert changes in trainer
dengdifan Jun 14, 2022
b86908f
revert format changes
dengdifan Jun 14, 2022
68d8a25
move constant_forecasting to constatn
dengdifan Jun 14, 2022
dac5cdd
additional annotate for base pipeline
dengdifan Jun 14, 2022
7f2d394
move forecasting check to tae
dengdifan Jun 14, 2022
e43d70a
maint time series refit dataset
dengdifan Jun 14, 2022
dc48b9d
fix test
dengdifan Jun 14, 2022
1e7253a
workflow for extra requirements
dengdifan Jun 14, 2022
83e2469
docs for time series dataset
dengdifan Jun 14, 2022
1671992
fix pre-commit
dengdifan Jun 14, 2022
97d3835
docs for dataset
dengdifan Jun 14, 2022
889c5e9
maint docstring
dengdifan Jun 14, 2022
f68dc18
merge target scaler to one file
dengdifan Jun 14, 2022
dc4f510
fix forecasting init cfgs
dengdifan Jun 14, 2022
951ef4e
remove redudant pipeline configs
dengdifan Jun 14, 2022
10f0c83
maint
dengdifan Jun 14, 2022
8574c6f
SMAC4HPO instead of SMAC4AC in smbo (will be reverted further if stud…
dengdifan Jun 15, 2022
86e39bc
fixed docstrign for RNN and Transformer Decoder
dengdifan Jun 15, 2022
21fbcb2
uniformed docstrings for smbo and base task
dengdifan Jun 15, 2022
ee66c25
correct encoder to decoder in decoder.init
dengdifan Jun 15, 2022
877a124
fix doc strings
dengdifan Jun 15, 2022
1d3a74e
add license and docstrings for NBEATS heads
dengdifan Jun 16, 2022
2516859
allow memory limit to be None
dengdifan Jun 16, 2022
fe5e587
relax test load for forecasting
dengdifan Jun 16, 2022
2c6f66f
fix docs
dengdifan Jun 16, 2022
bb7f5c5
fix pre-commit
dengdifan Jun 16, 2022
9d728b5
make test compatible with py37
dengdifan Jun 17, 2022
a331093
maint docstring
dengdifan Jun 17, 2022
8a5a91b
split forecasting_eval_train_function from eval_train_function
dengdifan Jun 17, 2022
acddd22
fix namespace for test_api from train_evaluator to tae
dengdifan Jun 17, 2022
b18ce92
maint test api for forecasting
dengdifan Jun 17, 2022
0700e61
decrease number of ensemble size of test_time_series_forecasting to r…
dengdifan Jun 17, 2022
e4328ee
flatten all the prediction for forecasting pipelines
dengdifan Jun 17, 2022
b6baef1
pre-commit fix
dengdifan Jun 17, 2022
c1de20f
Merge remote-tracking branch 'upstream/development' into time_series_…
dengdifan Jun 20, 2022
0771c8e
fix docstrings and typing
dengdifan Jun 20, 2022
d066fda
maint time series dataset docstrings
dengdifan Jun 22, 2022
f701df3
maint warning message in time_series_forecasting_train_evaluator
dengdifan Jun 22, 2022
5e970f6
fix lines that are overlength
dengdifan Jun 22, 2022
construct network
dengdifan committed Mar 2, 2022
commit 57461b903d7bd44b481094d813b8a9980dfca140
@@ -1,11 +1,8 @@
from collections import OrderedDict
from typing import Any, Dict, Optional, Union, Tuple, List
from enum import Enum

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import CategoricalHyperparameter, UniformIntegerHyperparameter
from ConfigSpace.conditions import EqualsCondition

import numpy as np
from abc import abstractmethod

import torch
from torch import nn
@@ -16,8 +13,6 @@
TransformedDistribution,
)

from autoPyTorch.constants import CLASSIFICATION_TASKS, STRING_TO_TASK_TYPES
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
from autoPyTorch.pipeline.components.preprocessing.time_series_preprocessing.forecasting_target_scaling. \
base_target_scaler import BaseTargetScaler
from autoPyTorch.pipeline.components.setup.network_backbone.\
@@ -28,12 +23,30 @@
NetworkStructure,
EncoderProperties
)
from autoPyTorch.pipeline.components.setup.network_backbone.forecasting_backbone.forecasting_encoder.seq_encoder.\
RNNEncoder import _RNN

from autoPyTorch.pipeline.components.setup.network_backbone.\
forecasting_backbone.forecasting_decoder.base_forecasting_decoder import (
DecoderBlockInfo,
DecoderProperties
)

from autoPyTorch.pipeline.components.setup.network_backbone.forecasting_backbone.components_util import AddLayer
from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import (
TimeDistributed, TimeDistributedInterpolation, GatedLinearUnit, ResampleNorm, AddNorm, GateAddNorm,
GatedResidualNetwork, VariableSelectionNetwork, InterpretableMultiHeadAttention
)


class EncoderOutputForm(Enum):
NoOutput = 0
HiddenStates = 1 # RNN -> RNN
Sequence = 2 # RNN/TCN/Transformer -> MLP
SequenceLast = 3 # Transformer -> Transformer



class TransformedDistribution_(TransformedDistribution):
"""
We implement the mean function such that we do not need to enquire base mean every time
@@ -151,7 +164,62 @@ def get_lagged_subsequences_inference(
return lagged_seq


class ForecastingNet(nn.Module):
class StackedEncoder(nn.Module):
def __init__(self,
network_structure: NetworkStructure,
has_temporal_fusion: bool,
network_encoder: Dict[str, EncoderBlockInfo],
network_decoder: Dict[str, DecoderBlockInfo],
):
self.num_blocks = network_structure.num_blocks
self.skip_connection = network_structure.skip_connection
self.has_temporal_fusion = has_temporal_fusion

self.encoder_output_type = {}
self.encoder_has_hidden_states = {}
encoder = nn.ModuleDict()
for i in range(1, self.num_blocks + 1):
block_id = f'block_{i}'
encoder[block_id] = network_encoder[block_id].encoder
if self.skip_connection:
input_size = network_encoder[block_id].encoder_output_shape_[-1]
skip_size = network_encoder[block_id].encoder_input_shape[-1]
if network_structure.skip_connection_type == 'add':
encoder[f'skip_connection_{i}'] = AddLayer(input_size, skip_size)
elif network_structure.skip_connection_type == 'gate_add_norm':
encoder[f'skip_connection_{i}'] = GateAddNorm(input_size,
hidden_size=input_size,
skip_size=skip_size,
dropout=network_structure.grn_dropout_rate)
if block_id in network_decoder:
if network_decoder[block_id].decoder_properties.recurrent:
if network_decoder[block_id].decoder_properties.has_hidden_states:
# RNN
self.encoder_has_decoder[i] = EncoderOutputForm.HiddenStates
else:
# Transformer
self.encoder_has_decoder[i] = EncoderOutputForm.Sequence
else:
self.encoder_has_decoder[i] = EncoderOutputForm.SequenceLast
else:
self.encoder_has_decoder[i] = EncoderOutputForm.NoOutput
if network_decoder[block_id].decoder_properties.has_hidden_states:
self.encoder_has_hidden_states[i] = True
else:
self.encoder_has_hidden_states[i] = False
self.encoder = encoder

def forward(self, encoder_input: torch.Tensor, additional_input: List[Optional[torch.Tensor]], output_seq: bool):
output_for_decoder = []
for i in range(1, self.num_blocks + 1):
if self.encoder_has_hidden_states[i]:
x, hx = self.encoder[f'block_{i}'](encoder_input, )
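
The decoder-dependent branching in `StackedEncoder.__init__` can be read as a pure function from a block's decoder properties to an `EncoderOutputForm`. The following is a minimal sketch, not the commit's code. It assumes that the values written to `self.encoder_has_decoder[i]` are meant to land in `self.encoder_output_type`, which is the dict the constructor actually initialises. `DecoderProps` and `select_output_form` are hypothetical names introduced only for illustration.

```python
from enum import Enum
from typing import NamedTuple, Optional


class EncoderOutputForm(Enum):
    NoOutput = 0
    HiddenStates = 1
    Sequence = 2
    SequenceLast = 3


class DecoderProps(NamedTuple):
    """Hypothetical stand-in for the two DecoderProperties fields used here."""
    recurrent: bool
    has_hidden_states: bool


def select_output_form(props: Optional[DecoderProps]) -> EncoderOutputForm:
    """Mirror the branching in the diff; None means the block has no decoder."""
    if props is None:
        return EncoderOutputForm.NoOutput
    if props.recurrent:
        if props.has_hidden_states:
            # commented as "RNN" in the diff
            return EncoderOutputForm.HiddenStates
        # commented as "Transformer" in the diff
        return EncoderOutputForm.Sequence
    return EncoderOutputForm.SequenceLast


if __name__ == '__main__':
    for props in (None, DecoderProps(True, True), DecoderProps(True, False), DecoderProps(False, False)):
        print(props, '->', select_output_form(props).name)
```

Handling the no-decoder case up front also avoids looking up `network_decoder[block_id]` for a block that has no entry, which the `has_hidden_states` check in the diff appears to do unconditionally.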





class AbstractForecastingNet(nn.Module):
future_target_required = False

def __init__(self,
@@ -163,6 +231,7 @@ def __init__(self,
window_size: int,
target_scaler: BaseTargetScaler,
dataset_properties: Dict,
auto_regressive: bool,
output_type: str = 'regression',
forecast_strategy: Optional[str] = 'mean',
num_samples: Optional[int] = 100,
@@ -180,6 +249,7 @@ def __init__(self,
network_decoder (nn.Module): network decoder
network_head (nn.Module): network head, maps the output of decoder to the final output
dataset_properties (Dict): dataset properties
auto_regressive (bool): if the overall model is auto-regressive model
encoder_properties (Dict): encoder properties
decoder_properties: (Dict): decoder properties
output_type (str): the form that the network outputs. It could be regression, distribution and
@@ -237,6 +307,103 @@ def __init__(self,
if self.decoder_lagged_input:
self.cached_lag_mask_decoder = None

if network_structure.variable_selection:
# TODO rewrite forecasting dataset to allow mutli-variant models!!!
first_encoder = network_encoder['block_1']
first_encoder_output_shape = network_encoder['block_1'].encoder_output_shape_[-1]
static_input_sizes = dataset_properties['static_features_shape']
variable_selector = nn.ModuleDict()
if static_input_sizes > 0:
variable_selector['static_variable_selection'] = VariableSelectionNetwork(
input_sizes=static_input_sizes,
hidden_size=first_encoder_output_shape,
input_embedding_flags={},
dropout=network_structure.grn_dropout_rate,
)
if dataset_properties['uni_variant']:
# variable selection for encoder and decoder
encoder_input_sizes = {
'past_targets': dataset_properties['input_shape'][-1],
'past_features': 0
}
decoder_input_sizes = {
'future_features': 0
}
if auto_regressive:
decoder_input_sizes.update({'future_prediction': dataset_properties['output_shape'][-1]})
else:
# TODO
pass

# create single variable grns that are shared across decoder and encoder
if network_structure.share_single_variable_networks:
variable_selector['shared_single_variable_grns'] = nn.ModuleDict()
for name, input_size in encoder_input_sizes.items():
variable_selector['shared_single_variable_grns'][name] = GatedResidualNetwork(
input_size,
min(input_size, first_encoder_output_shape),
first_encoder_output_shape,
network_structure.grn_dropout_rate,
)
for name, input_size in decoder_input_sizes.items():
if name not in self.shared_single_variable_grns:
variable_selector['shared_single_variable_grns'][name] = GatedResidualNetwork(
input_size,
min(input_size, first_encoder_output_shape),
first_encoder_output_shape,
network_structure.grn_dropout_rate,
)

variable_selector['encoder_variable_selection'] = VariableSelectionNetwork(
input_sizes=encoder_input_sizes,
hidden_size=first_encoder_output_shape,
input_embedding_flags={},
dropout=network_structure.grn_dropout_rate,
context_size=first_encoder_output_shape,
single_variable_grns={}
if not network_structure.share_single_variable_networks
else variable_selector['shared_single_variable_grns'],
)

variable_selector['encoder_variable_selection'] = VariableSelectionNetwork(
input_sizes=decoder_input_sizes,
hidden_size=self.hparams.hidden_size,
input_embedding_flags={},
dropout=network_structure.grn_dropout_rate,
context_size=first_encoder_output_shape,
single_variable_grns={}
if not network_structure.share_single_variable_networks
else variable_selector['shared_single_variable_grns'],
)

variable_selector['static_context_variable_selection'] = GatedResidualNetwork(
input_size=first_encoder_output_shape,
hidden_size=first_encoder_output_shape,
output_size=first_encoder_output_shape,
dropout=network_structure.grn_dropout_rate,
)

if first_encoder.encoder_properties.has_hidden_states:
if isinstance(first_encoder.encoder, _RNN):
# for hidden state of the rnn
variable_selector['static_context_initial_hidden_lstm'] = GatedResidualNetwork(
input_size=first_encoder_output_shape,
hidden_size=first_encoder_output_shape,
output_size=first_encoder_output_shape,
dropout=network_structure.grn_dropout_rate,
)
if first_encoder.encoder.cell_type == 'lstm':
# for cell state of the lstm
variable_selector['static_context_initial_cell_lstm'] = GatedResidualNetwork(
input_size=first_encoder_output_shape,
hidden_size=first_encoder_output_shape,
output_size=first_encoder_output_shape,
dropout=network_structure.grn_dropout_rate,
)
else:
raise NotImplementedError
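
For the univariate case, the input-size dictionaries handed to the encoder and decoder `VariableSelectionNetwork` instances depend only on `dataset_properties` and `auto_regressive`. The following is a minimal sketch of that bookkeeping. `variable_selection_input_sizes` is a hypothetical helper, not part of the commit, and the multivariate branch is left unimplemented as it is in the diff.

```python
from typing import Any, Dict, Tuple


def variable_selection_input_sizes(
    dataset_properties: Dict[str, Any],
    auto_regressive: bool,
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Return (encoder_input_sizes, decoder_input_sizes) for univariate data."""
    if not dataset_properties['uni_variant']:
        # The commit leaves the multivariate case as a TODO.
        raise NotImplementedError('multi-variant datasets are not handled yet')

    encoder_input_sizes = {
        'past_targets': dataset_properties['input_shape'][-1],
        'past_features': 0,
    }
    decoder_input_sizes = {'future_features': 0}
    if auto_regressive:
        # An auto-regressive decoder also consumes its own previous predictions.
        decoder_input_sizes['future_prediction'] = dataset_properties['output_shape'][-1]
    return encoder_input_sizes, decoder_input_sizes


if __name__ == '__main__':
    # Shapes below are made-up values for illustration only.
    props = {'uni_variant': True, 'input_shape': (24, 1), 'output_shape': (12, 1)}
    print(variable_selection_input_sizes(props, auto_regressive=True))
```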


@property
def device(self):
return self._device
@@ -280,13 +447,44 @@ def scale_value(self,
outputs = (outputs - loc.to(device)) / scale.to(device)
return outputs

@abstractmethod
def forward(self,
past_targets: torch.Tensor,
future_targets: Optional[torch.Tensor] = None,
past_features: Optional[torch.Tensor] = None,
future_features: Optional[torch.Tensor] = None,
static_features: Optional[torch.Tensor] = None,
hidden_states: Optional[Tuple[torch.Tensor]] = None):
encoder_length: Optional[torch.Tensor] = None,
decoder_observed_values: Optional[torch.Tensor] = None,
hidden_states: Optional[Tuple[torch.Tensor]] = None,
):
raise NotImplementedError

@abstractmethod
def pred_from_net_output(self, net_output):
raise NotImplementedError

@abstractmethod
def predict(self,
past_targets: torch.Tensor,
past_features: Optional[torch.Tensor] = None,
future_features: Optional[torch.Tensor] = None,
static_features: Optional[torch.Tensor] = None
):
raise NotImplementedError


class ForecastingNet(AbstractForecastingNet):
def forward(self,
past_targets: torch.Tensor,
future_targets: Optional[torch.Tensor] = None,
past_features: Optional[torch.Tensor] = None,
future_features: Optional[torch.Tensor] = None,
static_features: Optional[torch.Tensor] = None,
encoder_length: Optional[torch.Tensor] = None,
decoder_observed_values: Optional[torch.Tensor] = None,
hidden_states: Optional[Tuple[torch.Tensor]] = None,
):
if self.encoder_lagged_input:
past_targets[:, -self.window_size:], _, loc, scale = self.target_scaler(past_targets[:, -self.window_size:])
past_targets[:, :-self.window_size] = self.scale_value(past_targets[:, :-self.window_size], loc, scale)
@@ -75,9 +75,9 @@ def _required_fit_requirements(self):
FitRequirement("network_decoder", (Dict[str, DecoderBlockInfo]), user_defined=False,
dataset_property=False),
FitRequirement("network_head", (Optional[torch.nn.Module],), user_defined=False, dataset_property=False),
FitRequirement("auto_regressive", (bool,), user_defined=False, dataset_property=False),
FitRequirement("target_scaler", (BaseTargetScaler,), user_defined=False, dataset_property=False),
FitRequirement("required_net_out_put_type", (str,), user_defined=False, dataset_property=False),
FitRequirement("encoder_properties_1", (Dict,), user_defined=False, dataset_property=False),
]

def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent:
@@ -99,13 +99,16 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent:
network_encoder=network_encoder,
network_decoder=network_decoder,
network_head=X['network_head'],
auto_regressive=X['auto_regressive'],
window_size=X['window_size'],
dataset_properties=X['dataset_properties'],
target_scaler=X['target_scaler'],
output_type=self.net_out_type,
forecast_strategy=self.forecast_strategy,
num_samples=self.num_samples,
aggregation=self.aggregation, )
import pdb
pdb.set_trace()

if X['decoder_properties']['recurrent']:
# decoder is RNN or Transformer
@@ -25,12 +25,13 @@ def __init__(self, input_size: int, skip_size: int):
super().__init__()
if input_size == skip_size:
self.fc = nn.Linear(skip_size, input_size)
self.norm = nn.LayerNorm(input_size)

def forward(self, input: torch.Tensor, skip: torch.Tensor):
if hasattr(self, 'fc'):
return input + self.fc(skip)
return self.norm(input + self.fc(skip))
else:
return input
return self.norm(input)


class TemporalFusionLayer(nn.Module):
@@ -26,6 +26,8 @@ class DecoderProperties(NamedTuple):
class DecoderBlockInfo(NamedTuple):
decoder: nn.Module
decoder_properties: DecoderProperties
decoder_output_shape: Tuple[int, ...]
decoder_input_shape: Tuple[int, ...]
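
With the two shape fields added, a `DecoderBlockInfo` carries the shapes a downstream component needs to wire blocks together. The following is a minimal sketch using only the standard library. `Any` stands in for `nn.Module` and `DecoderProperties`, and all numbers are made up for illustration.

```python
from typing import Any, NamedTuple, Tuple


class DecoderBlockInfo(NamedTuple):
    decoder: Any  # nn.Module in the commit
    decoder_properties: Any  # DecoderProperties in the commit
    decoder_output_shape: Tuple[int, ...]
    decoder_input_shape: Tuple[int, ...]


# Illustrative values only; in the commit these come from the fitted decoder.
n_prediction_heads, n_decoder_output_features = 12, 32

# Keyword arguments keep the two shape fields from being swapped by accident.
info = DecoderBlockInfo(
    decoder=None,
    decoder_properties=None,
    decoder_input_shape=(24, 64),
    decoder_output_shape=(n_prediction_heads, n_decoder_output_features),
)

if __name__ == '__main__':
    print(info.decoder_input_shape, info.decoder_output_shape)
```

Constructing the tuple by keyword, as the updated `transform` does, matters here because both new fields share the type `Tuple[int, ...]`, so a positional mix-up would not be caught by a type checker.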


class DecoderNetwork(nn.Module):
@@ -62,6 +64,7 @@ def __init__(self,
self.config = kwargs
self.decoder: Optional[nn.Module] = None
self.n_decoder_output_features = None
self.decoder_input_shape = None
self.n_prediction_heads = 1
self.is_last_decoder = False

@@ -133,6 +136,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
n_prediction_heads=self.n_prediction_heads,
dataset_properties=X['dataset_properties']
)
self.decoder_input_shape = encoder_output_shape

X['n_decoder_output_features'] = self.n_decoder_output_features
return self
@@ -150,8 +154,12 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
# 'n_prediction_heads' and 'n_decoder_output_features' are only applied to the head such that they could be
# overwritten by the following decoders
network_decoder = X.get('network_decoder', OrderedDict())
network_decoder[f'block_{self.block_number}'] = DecoderBlockInfo(decoder=self.decoder,
decoder_properties=self.decoder_properties())
network_decoder[f'block_{self.block_number}'] = DecoderBlockInfo(
decoder=self.decoder,
decoder_properties=self.decoder_properties(),
decoder_input_shape=self.decoder_input_shape,
decoder_output_shape=(self.n_prediction_heads, self.n_decoder_output_features)
)
if self.is_last_decoder:
X.update({f'network_decoder': network_decoder,
'n_prediction_heads': self.n_prediction_heads,