Skip to content

[BUG] Correcly set lagged variables to known when lag >= horizon #1910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pytorch_forecasting/data/timeseries/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,10 +779,10 @@ def _append_if_new(lst, x):
if name in var_names:
for lagged_name, lag in lagged_names.items():
# if lag is longer than horizon, lagged var becomes future-known
if known or lag < self.max_prediction_length:
_append_if_new(var_names, lagged_name)
elif lag < self.max_prediction_length:
if known == "known" or lag >= self.max_prediction_length:
_append_if_new(_attr(realcat, "known"), lagged_name)
else:
_append_if_new(_attr(realcat, "unknown"), lagged_name)

@property
def dropout_categoricals(self) -> list[str]:
Expand Down
79 changes: 79 additions & 0 deletions tests/test_data/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,85 @@ def test_lagged_variables(test_data, kwargs):
).all(), "lagged target must be the same as non-lagged target"


def test_lagged_variable_known_unknown_assignment(test_data):
"""
Test that lagged variables are assigned to known or unknown variables correctly:
- If lag < max_prediction_length: lagged variable is unknown
- If lag >= max_prediction_length: lagged variable is known
"""
# Setup: one known real, one unknown real, one known cat, one unknown cat
dataset = TimeSeriesDataSet(
test_data.copy(),
time_idx="time_idx",
target="volume",
group_ids=["agency", "sku"],
max_encoder_length=5,
max_prediction_length=2,
min_prediction_length=1,
min_encoder_length=3,
time_varying_unknown_reals=["volume"],
time_varying_known_categoricals=["month"],
time_varying_unknown_categoricals=["agency"],
lags={"volume": [1, 2, 3], "agency": [1, 2, 3], "month": [1, 2, 3]},
)

horizon = dataset.max_prediction_length

for var in ["volume"]:
is_known = var in dataset._time_varying_known_reals
for lag in [1, 2, 3]:
lagged_name = f"{var}_lagged_by_{lag}"
if is_known:
assert (
lagged_name in dataset._time_varying_known_reals
), f"{lagged_name} should be known real (from known real)"
assert (
lagged_name not in dataset._time_varying_unknown_reals
), f"{lagged_name} should not be unknown real (from known real)"
else:
if lag >= horizon:
assert (
lagged_name in dataset._time_varying_known_reals
), f"{lagged_name} should be known real (lag >= horizon)"
assert (
lagged_name not in dataset._time_varying_unknown_reals
), f"{lagged_name} should not be unknown real (lag >= horizon)"
else:
assert (
lagged_name in dataset._time_varying_unknown_reals
), f"{lagged_name} should be unknown real (lag < horizon)"
assert (
lagged_name not in dataset._time_varying_known_reals
), f"{lagged_name} should not be known real (lag < horizon)"

for var in ["agency", "month"]:
is_known = var in dataset._time_varying_known_categoricals
for lag in [1, 2, 3]:
lagged_name = f"{var}_lagged_by_{lag}"
if is_known:
assert (
lagged_name in dataset._time_varying_known_categoricals
), f"{lagged_name} should be known cat (from known cat)"
assert (
lagged_name not in dataset._time_varying_unknown_categoricals
), f"{lagged_name} should not be unknown cat (from known cat)"
else:
if lag >= horizon:
assert (
lagged_name in dataset._time_varying_known_categoricals
), f"{lagged_name} should be known cat (lag >= horizon)"
assert (
lagged_name not in dataset._time_varying_unknown_categoricals
), f"{lagged_name} should not be unknown cat (lag >= horizon)"
else:
assert (
lagged_name in dataset._time_varying_unknown_categoricals
), f"{lagged_name} should be unknown cat (lag < horizon)"
assert (
lagged_name not in dataset._time_varying_known_categoricals
), f"{lagged_name} should not be known cat (lag < horizon)"


@pytest.mark.parametrize(
"agency,first_prediction_idx,should_raise",
[
Expand Down
Loading