Imp/tsmixer basic #2555

Open
eschibli wants to merge 35 commits into base: master from Imp/tsmixer-basic
Commits (35)
e12a751
Add tsmixer-basic
eschibli Sep 3, 2024
1bd17da
Implimented project_first
eschibli Sep 8, 2024
34e431f
Merge branch 'master' of https://github.com/unit8co/darts
eschibli Oct 8, 2024
bb72922
Updated changelog
eschibli Oct 8, 2024
1934928
Linting
eschibli Oct 9, 2024
c5de724
Additional linting
eschibli Oct 9, 2024
b8a2e88
Further updated changelog
eschibli Oct 11, 2024
c6c7e97
More linting?
eschibli Oct 11, 2024
79612f0
Removed unnecessary layer init
eschibli Oct 23, 2024
6cdc9db
linting
eschibli Oct 23, 2024
171dc34
Reverted example
eschibli Oct 24, 2024
f9797a3
linting????
eschibli Oct 24, 2024
e528f85
auto formatting
eschibli Oct 25, 2024
133547e
Merge branch 'master' of https://github.com/unit8co/darts into Imp/ts…
eschibli Oct 25, 2024
5afff01
Merge branch 'master' of https://github.com/unit8co/darts into Imp/ts…
eschibli Oct 27, 2024
3a392a3
Added test
eschibli Oct 27, 2024
ebe02d1
Improved test coverage
eschibli Oct 28, 2024
0a90f24
Docustring tweak
eschibli Oct 31, 2024
2674e1c
Merge branch 'master' of https://github.com/unit8co/darts into Imp/ts…
eschibli Nov 3, 2024
2bf09ce
Merge branch 'master' into Imp/tsmixer-basic
dennisbader Nov 7, 2024
245ae09
Merge branch 'master' into Imp/tsmixer-basic
madtoinou Nov 12, 2024
f9c0d15
Merge branch 'master' into Imp/tsmixer-basic
eschibli Dec 17, 2024
51e2b11
Merge branch 'master' into Imp/tsmixer-basic
eschibli Feb 11, 2025
c86b559
Merge branch 'master' into Imp/tsmixer-basic
madtoinou Mar 5, 2025
e6647c0
Merge branch 'unit8co:master' into Imp/tsmixer-basic
eschibli Mar 10, 2025
1483096
Merge branch 'master' into Imp/tsmixer-basic
eschibli Mar 25, 2025
97c83b2
Added project_after_n_layers to tsmixer
eschibli Mar 25, 2025
bf912d9
Merge branch 'unit8co:master' into Imp/tsmixer-basic
eschibli Apr 19, 2025
46cf888
Relaxed TSMixer performance theshold
eschibli Apr 20, 2025
c374724
Try again
eschibli Apr 20, 2025
67a89f1
Merge branch 'master' into Imp/tsmixer-basic
dennisbader Apr 20, 2025
ef1a322
Try again
eschibli Apr 21, 2025
0dd0c30
Corrected changelog
eschibli Apr 25, 2025
a4891e2
Added additional test
eschibli Apr 29, 2025
a46645a
Merge branch 'master' into Imp/tsmixer-basic
eschibli Apr 29, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- 🟠 Renamed `RandomForest` to `RandomForestModel`. Using `RandomForest` will raise a deprecation warning.
- 🔴 Renamed `RegressionModelWithCategoricalCovariates` to `SKLearnModelWithCategoricalCovariates`. Removed `RegressionModelWithCategoricalCovariates`

- Added `project_after_n_blocks` hyperparameter to `TSMixerModel`, allowing some or all of the backbone to operate in the lookback rather than the forecast time space. [#2555](https://github.com/unit8co/darts/pull/2555) by [Eric Schibli](https://github.com/eschibli).

**Fixed**

- Fixed some issues in `NLinearModel` with `normalize=True` that resulted in decreased predictive accuracy. Using `shared_weights=True` and auto-regressive forecasting now work properly. [#2757](https://github.com/unit8co/darts/pull/2757) by [Timon Erhart](https://github.com/turbotimon).
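
For illustration, a minimal usage sketch of the new `project_after_n_blocks` option (assuming the argument behaves as added in this PR; data and sizes are arbitrary):

from darts.models import TSMixerModel
from darts.utils.timeseries_generation import sine_timeseries

# Toy series; length, chunk sizes and epochs are arbitrary values for the sketch.
series = sine_timeseries(length=200, freq="h")

# With the default num_blocks=2, project_after_n_blocks=2 keeps the whole
# backbone in the lookback time space and projects to the forecast horizon
# only at the end; project_after_n_blocks=0 reproduces the original TSMixer
# behaviour of projecting before the first mixer block.
model = TSMixerModel(
    input_chunk_length=24,
    output_chunk_length=12,
    project_after_n_blocks=2,
    n_epochs=5,
)
model.fit(series)
pred = model.predict(n=12)
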
148 changes: 108 additions & 40 deletions darts/models/forecasting/tsmixer_model.py
@@ -267,6 +267,7 @@ def __init__(
super().__init__()

mixing_input = input_dim

if static_cov_dim != 0:
self.feature_mixing_static = _FeatureMixing(
sequence_length=sequence_length,
@@ -317,28 +318,30 @@ def __init__(
self,
input_dim: int,
output_dim: int,
num_encoder_blocks: int,
num_decoder_blocks: int,
past_cov_dim: int,
future_cov_dim: int,
static_cov_dim: int,
nr_params: int,
hidden_size: int,
ff_size: int,
num_blocks: int,
activation: str,
dropout: float,
norm_type: Union[str, nn.Module],
normalize_before: bool,
**kwargs,
) -> None:
"""
Initializes the TSMixer module for use within a Darts forecasting model.
Initializes the TSMixer module.

Parameters
----------
input_dim
Number of input target features.
output_dim
Number of output target features.

past_cov_dim
Number of past covariate features.
future_cov_dim
@@ -352,8 +355,6 @@ def __init__(
Hidden state size of the TSMixer.
ff_size
Dimension of the feedforward network internal to the module.
num_blocks
Number of mixer blocks.
activation
Activation function to use.
dropout
@@ -362,6 +363,11 @@ def __init__(
Type of normalization to use.
normalize_before
Whether to apply normalization before or after mixing.
num_encoder_blocks
Number of mixer blocks applied in the lookback (encoder) time space, before the
time dimension is projected to the forecast horizon.
num_decoder_blocks
Number of mixer blocks applied in the forecast (decoder) time space, after the
projection. Placing blocks in the encoder is recommended if there are no future
covariates, while placing blocks in the decoder is recommended if there are
important future covariates.
"""
super().__init__(**kwargs)
self.input_dim = input_dim
@@ -373,7 +379,7 @@ def __init__(
if activation not in ACTIVATIONS:
raise_log(
ValueError(
f"Invalid `activation={activation}`. Must be on of {ACTIVATIONS}."
f"Invalid `activation={activation}`. Must be one of {ACTIVATIONS}."
),
logger=logger,
)
@@ -383,7 +389,7 @@ def __init__(
if norm_type not in NORMS:
raise_log(
ValueError(
f"Invalid `norm_type={norm_type}`. Must be on of {NORMS}."
f"Invalid `norm_type={norm_type}`. Must be one of {NORMS}."
),
logger=logger,
)
@@ -402,13 +408,17 @@ def __init__(
"normalize_before": normalize_before,
}

# Projects from the input time dimension to the output time dimension
self.fc_hist = nn.Linear(self.input_chunk_length, self.output_chunk_length)

self.feature_mixing_hist = _FeatureMixing(
sequence_length=self.output_chunk_length,
sequence_length=self.input_chunk_length,
input_dim=input_dim + past_cov_dim + future_cov_dim,
output_dim=hidden_size,
**mixer_params,
)

# Process future covariates in decoder (if exists)
if future_cov_dim:
self.feature_mixing_future = _FeatureMixing(
sequence_length=self.output_chunk_length,
@@ -418,19 +428,40 @@ def __init__(
)
else:
self.feature_mixing_future = None
self.conditional_mixer = self._build_mixer(
prediction_length=self.output_chunk_length,
num_blocks=num_blocks,

# Projection from the encoder time dimension (input_chunk_length) to the
# decoder time dimension (output_chunk_length), replacing the old fc_hist/fc_future
self.encoder_to_decoder = nn.Linear(
self.input_chunk_length, self.output_chunk_length
)

# Build encoder mixer (operating on input_chunk_length)
self.encoder_mixer = self._build_mixer(
sequence_length=self.input_chunk_length,
num_blocks=num_encoder_blocks,
hidden_size=hidden_size,
future_cov_dim=0, # encoder mixing uses only historical features
static_cov_dim=static_cov_dim,
**mixer_params,
)
# Build decoder mixer (operating on output_chunk_length)
self.decoder_mixer = self._build_mixer(
sequence_length=self.output_chunk_length,
num_blocks=num_decoder_blocks,
hidden_size=hidden_size,
future_cov_dim=future_cov_dim,
static_cov_dim=static_cov_dim,
**mixer_params,
)
self.fc_out = nn.Linear(hidden_size, output_dim * nr_params)

self.fc_out = nn.Linear(
hidden_size * (1 + int((num_decoder_blocks == 0) and (future_cov_dim > 0))),
output_dim * nr_params,
)

@staticmethod
def _build_mixer(
prediction_length: int,
sequence_length: int,
num_blocks: int,
hidden_size: int,
future_cov_dim: int,
@@ -441,14 +472,15 @@ def _build_mixer(
# the first block takes `x` consisting of concatenated features with size `hidden_size`:
# - historic features
# - optional future features
input_dim_block = hidden_size * (1 + int(future_cov_dim > 0))

input_dim_block = hidden_size * (
1 + int(future_cov_dim > 0)
) # starting dimension for mixer layers
mixer_layers = nn.ModuleList()
for _ in range(num_blocks):
layer = _ConditionalMixerLayer(
input_dim=input_dim_block,
output_dim=hidden_size,
sequence_length=prediction_length,
sequence_length=sequence_length,
static_cov_dim=static_cov_dim,
**kwargs,
)
@@ -480,45 +512,60 @@ def forward(
# B: batch size
# L: input chunk length
# T: output chunk length
# SL: mixer block time dimension (L in the encoder blocks, T in the decoder blocks)
# C: target components
# P: past cov features
# F: future cov features
# S: static cov features
# H = C + P + F: historic features
# H_S: hidden Size
# N_P: likelihood parameters
# N_P: number of likelihood parameters

# `x`: (B, L, H), `x_future`: (B, T, F), `x_static`: (B, C or 1, S)
x, x_future, x_static = x_in

# swap feature and time dimensions (B, L, H) -> (B, H, L)
x = _time_to_feature(x)
# linear transformations to horizon (B, H, L) -> (B, H, T)
x = self.fc_hist(x)
# (B, H, T) -> (B, T, H)
x = _time_to_feature(x)
if self.static_cov_dim:
# (B, C, S) -> (B, 1, C * S)
x_static_hist = x_static.reshape(x_static.shape[0], 1, -1)
# repeat to match lookback time dim: (B, 1, C * S) -> (B, L, C * S)
x_static_hist = x_static_hist.repeat(1, self.input_chunk_length, 1)

# (B, C, S) -> (B, 1, C * S)
x_static_future = x_static.reshape(x_static.shape[0], 1, -1)
# repeat to match horizon time dim: (B, 1, C * S) -> (B, T, C * S)
x_static_future = x_static_future.repeat(1, self.output_chunk_length, 1)
else:
x_static_hist = None
x_static_future = None

# feature mixing for historical features (B, T, H) -> (B, T, H_S)
# Process historical data (B, L, H) -> (B, L, H_S)
x = self.feature_mixing_hist(x)

# Process future data (B, T, F) -> (B, T, H_S)
if self.future_cov_dim:
# feature mixing for future features (B, T, F) -> (B, T, H_S)
x_future = self.feature_mixing_future(x_future)
# (B, T, H_S) + (B, T, H_S) -> (B, T, 2*H_S)
x = torch.cat([x, x_future], dim=-1)

if self.static_cov_dim:
# (B, C, S) -> (B, 1, C * S)
x_static = x_static.reshape(x_static.shape[0], 1, -1)
# repeat to match horizon (B, 1, C * S) -> (B, T, C * S)
x_static = x_static.repeat(1, self.output_chunk_length, 1)
# Apply encoder mixer layers
for layer in self.encoder_mixer:
x = layer(x, x_static_hist)

# Project time dimension (B, L, H_S) -> (B, T, H_S)
x = x.transpose(1, 2)
x = self.encoder_to_decoder(x) # Linear map
x = x.transpose(1, 2)

# If future covariates are provided, mix and concatenate them with the encoder output
if self.future_cov_dim:
# (B, T, H_S) -> (B, T, 2 * H_S)
x = torch.cat([x, x_future], dim=-1)

for mixing_layer in self.conditional_mixer:
# conditional mixer layers with static covariates (B, T, 2 * H_S), (B, T, C * S) -> (B, T, H_S)
x = mixing_layer(x, x_static=x_static)
# Apply decoder mixer layers
for layer in self.decoder_mixer:
x = layer(x, x_static_future)

# linear transformation to generate the forecast (B, T, H_S) -> (B, T, C * N_P)
# Forecast generation
# (B, T, H_S) -> (B, T, C * N_P)
x = self.fc_out(x)
# (B, T, C * N_P) -> (B, T, C, N_P)
x = x.view(-1, self.output_chunk_length, self.output_dim, self.nr_params)
return x
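
The encoder-to-decoder step above is just a linear map over the time axis. A standalone shape sketch in plain PyTorch (independent of the Darts internals; all sizes are arbitrary):

import torch
import torch.nn as nn

B, L, T, H_S = 8, 24, 12, 64  # batch, lookback, horizon, hidden size

x = torch.randn(B, L, H_S)            # encoder output: (B, L, H_S)
encoder_to_decoder = nn.Linear(L, T)  # maps the time axis from length L to T

x = x.transpose(1, 2)                 # (B, H_S, L): put the time axis last
x = encoder_to_decoder(x)             # (B, H_S, T)
x = x.transpose(1, 2)                 # (B, T, H_S): back to time-major layout

assert x.shape == (B, T, H_S)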

@@ -537,6 +584,7 @@ def __init__(
norm_type: Union[str, nn.Module] = "LayerNorm",
normalize_before: bool = False,
use_static_covariates: bool = True,
project_after_n_blocks: int = 0,
**kwargs,
) -> None:
"""Time-Series Mixer (TSMixer): An All-MLP Architecture for Time Series.
@@ -578,8 +626,10 @@ def __init__(
The hidden state size / size of the second feed-forward layer in the feature mixing MLP.
ff_size
The size of the first feed-forward layer in the feature mixing MLP.
num_blocks
The number of mixer blocks in the model. The number includes the first block and all subsequent blocks.
project_after_n_blocks
The number of mixer blocks to apply in the lookback time space before projecting to the
forecast horizon (default=0, i.e. project before the first block, as in the original
TSMixer). Must be between 0 and num_blocks inclusive.
activation
The activation function to use in the mixer layers (default='ReLU').
Supported activations: ['ReLU', 'RReLU', 'PReLU', 'ELU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid',
@@ -770,11 +820,12 @@ def encode_year(idx):
# Model specific parameters
self.ff_size = ff_size
self.dropout = dropout
self.num_blocks = num_blocks
self.activation = activation
self.normalize_before = normalize_before
self.norm_type = norm_type
self.hidden_size = hidden_size
self.num_blocks = num_blocks
self.project_after_n_blocks = project_after_n_blocks
self._considers_static_covariates = use_static_covariates

def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Module:
@@ -801,6 +852,22 @@ def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Modu
input_dim = past_target.shape[1]
output_dim = future_target.shape[1]

num_encoder_blocks = self.project_after_n_blocks
num_decoder_blocks = self.num_blocks - num_encoder_blocks

# Raise exception for nonsensical number of encoder and decoder blocks
if (
num_encoder_blocks < 0
or num_decoder_blocks < 0
or (num_encoder_blocks + num_decoder_blocks != self.num_blocks)
):
raise_log(
ValueError(
f"Invalid number of encoder and decoder blocks. "
f"project_after_n_blocks must be between 0 and {self.num_blocks} inclusive."
),
)

static_cov_dim = (
static_covariates.shape[0] * static_covariates.shape[1]
if static_covariates is not None
@@ -821,7 +888,8 @@ def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Modu
nr_params=nr_params,
hidden_size=self.hidden_size,
ff_size=self.ff_size,
num_blocks=self.num_blocks,
num_encoder_blocks=num_encoder_blocks,
num_decoder_blocks=num_decoder_blocks,
activation=self.activation,
dropout=self.dropout,
norm_type=self.norm_type,
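
The encoder/decoder split computed in `_create_model` above can be summarised with a small worked example (illustrative values; num_blocks=2 matches the model default exercised by the test below):

num_blocks = 2

for project_after_n_blocks in (-1, 0, 1, 2, 3):
    num_encoder_blocks = project_after_n_blocks
    num_decoder_blocks = num_blocks - num_encoder_blocks
    valid = num_encoder_blocks >= 0 and num_decoder_blocks >= 0
    print(project_after_n_blocks, num_encoder_blocks, num_decoder_blocks, valid)

# -1 -> invalid, 0 -> (0, 2), 1 -> (1, 1), 2 -> (2, 0), 3 -> invalid
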
@@ -170,7 +170,7 @@
"n_epochs": 10,
"pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
},
60.0,
75.0,
),
(
GlobalNaiveAggregate,
36 changes: 36 additions & 0 deletions darts/tests/models/forecasting/test_tsmixer.py
@@ -362,3 +362,39 @@ def test_time_batch_norm_2d_gradients(self):
output.mean().backward()

assert input_tensor.grad is not None

@pytest.mark.parametrize("project_after_n_blocks", [-1, 0, 1, 2, 3])
def test_project_after_n_blocks(self, project_after_n_blocks):
ts = tg.sine_timeseries(length=36, freq="h")
input_len = 12
output_len = 6

expect_exception = project_after_n_blocks == -1 or project_after_n_blocks == 3

if expect_exception:
with pytest.raises(ValueError):
model = TSMixerModel(
input_chunk_length=input_len,
output_chunk_length=output_len,
n_epochs=1,
project_after_n_blocks=project_after_n_blocks,
**tfm_kwargs,
)
model.fit(ts)

else:
model = TSMixerModel(
input_chunk_length=input_len,
output_chunk_length=output_len,
n_epochs=1,
project_after_n_blocks=project_after_n_blocks,
# Cover case of projecting future covs back to input dims
add_encoders={"cyclic": {"future": "hour"}},
**tfm_kwargs,
)
model.fit(ts)
model.predict(n=output_len, series=ts)

# Assert that the encoder and decoder mixers have the expected number of blocks
assert len(model.model.decoder_mixer) == 2 - project_after_n_blocks
assert len(model.model.encoder_mixer) == project_after_n_blocks