Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion functime/base/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,9 @@ def conformalize(
from functime.conformal import conformalize

self.fit(y=y, X=X)

y_pred = self.predict(fh=fh, X=X_future)

y_preds, y_resids = self.backtest(
y=y,
X=X,
Expand All @@ -349,7 +351,14 @@ def conformalize(
drop_short=drop_short,
drop_tolerance=drop_tolerance,
)
y_pred_qnts = conformalize(y_pred, y_preds, y_resids, alphas=alphas)

y_pred_qnts = conformalize(
y_pred=y_pred,
y_preds=y_preds,
y_resids=y_resids,
alphas=alphas,
)

if return_results:
return y_pred, y_pred_qnts, y_preds, y_resids
return y_pred_qnts
129 changes: 81 additions & 48 deletions functime/conformal.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,109 @@
from __future__ import annotations

from typing import List, Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Sequence, Union

import polars as pl


def enbpi(
y_pred: pl.LazyFrame,
y_resid: pl.LazyFrame,
alphas: List[float],
def conformalize(
    *,
    y_pred: Union[pl.DataFrame, pl.LazyFrame],
    y_preds: Union[pl.DataFrame, pl.LazyFrame],
    y_resids: Union[pl.DataFrame, pl.LazyFrame],
    alphas: Optional[Sequence[float]] = None,
) -> pl.DataFrame:
    """Compute prediction intervals using ensemble batch prediction intervals (ENBPI).

    The first three columns of ``y_pred`` are taken as the entity, time and
    target columns. The backtest predictions are cast to the time/target
    dtypes of ``y_pred`` and stacked underneath it before the residual
    quantiles are applied.

    Parameters
    ----------
    y_pred : Union[pl.DataFrame, pl.LazyFrame]
        The predicted values.
    y_preds : Union[pl.DataFrame, pl.LazyFrame]
        The predictions resulting from backtesting.
    y_resids : Union[pl.DataFrame, pl.LazyFrame]
        The backtesting residuals. Only the first three columns are used.
    alphas : Optional[Sequence[float]]
        The quantile levels to use for the prediction intervals. Defaults to (0.1, 0.9).
        Quantiles must be two values between 0 and 1 (exclusive).

    Returns
    -------
    pl.DataFrame
        The prediction intervals, with the ``quantile`` column expressed as an
        integer percentage (e.g. 0.1 -> 10).

    Raises
    ------
    ValueError
        If ``alphas`` does not pass validation.
    """
    alphas = _validate_alphas(alphas)

    entity_col, time_col, target_col = y_pred.columns[:3]
    schema = y_pred.schema

    _y_resids: pl.LazyFrame = y_resids.lazy().select(y_resids.columns[:3])
    _y_preds: pl.LazyFrame = pl.concat(
        [
            y_pred.lazy(),
            y_preds.lazy().select(
                entity_col,
                pl.col(time_col).cast(schema[time_col]),
                pl.col(target_col).cast(schema[target_col]),
            ),
        ]
    )

    y_pred_quantiles = _compute_enbpi(
        y_preds=_y_preds,
        y_resids=_y_resids,
        alphas=alphas,
    )

    # Make alpha base 100.
    # Round before casting: a float -> int cast truncates, and products such
    # as 0.29 * 100 evaluate to 28.999999999999996, which would otherwise be
    # labelled 28 instead of 29.
    # NOTE(review): levels with more than two decimal places still collapse
    # onto the nearest integer percentage (0.111 and 0.11 both become 11), as
    # pointed out in review; the integer column cannot represent them.
    return y_pred_quantiles.with_columns(
        (pl.col("quantile") * 100).round().cast(pl.Int16)
    )


def _compute_enbpi(
    *,
    y_preds: pl.LazyFrame,
    y_resids: pl.LazyFrame,
    alphas: Sequence[float],
) -> pl.DataFrame:
    """Shift predictions by per-entity residual quantiles (ENBPI).

    For every level in ``alphas`` the quantile of the last column of
    ``y_resids`` is computed per entity, left-joined onto ``y_preds`` by the
    entity column and added to the last column of ``y_preds``. The per-level
    frames are stacked, sorted by entity and time, and collected.

    Parameters
    ----------
    y_preds : pl.LazyFrame
        Predictions; the first two columns are the entity and time columns and
        the last column holds the predicted values.
    y_resids : pl.LazyFrame
        Residuals; the last column holds the residual values.
    alphas : Sequence[float]
        Quantile levels to evaluate.

    Returns
    -------
    pl.DataFrame
        Entity, time, shifted prediction and a ``quantile`` column holding the
        level that produced each row.
    """
    entity_col, time_col = y_preds.columns[:2]
    target_col = y_preds.columns[-1]
    resid_col = y_resids.columns[-1]
    schema = y_preds.schema

    def _shift_by_quantile(alpha: float) -> pl.LazyFrame:
        # Constant residual quantile per entity, used as the interval offset.
        scores = y_resids.group_by(entity_col).agg(
            pl.col(resid_col).quantile(alpha).alias("score")
        )
        joined = y_preds.join(scores.lazy(), how="left", on=entity_col)
        return joined.select(
            [
                pl.col(entity_col).cast(schema[entity_col]),
                pl.col(time_col).cast(schema[time_col]),
                pl.col(target_col) + pl.col("score"),
                pl.lit(alpha).alias("quantile"),
            ]
        )

    per_level = [_shift_by_quantile(alpha) for alpha in alphas]
    return pl.concat(per_level).sort([entity_col, time_col]).collect()


def conformalize(
y_pred: pl.DataFrame,
y_preds: pl.DataFrame,
y_resids: pl.DataFrame,
alphas: Optional[List[float]] = None,
) -> pl.DataFrame:
"""Compute prediction intervals using ensemble batch prediction intervals (ENBPI)."""

alphas = alphas or [0.1, 0.9]
entity_col, time_col, target_col = y_pred.columns[:3]
schema = y_pred.schema
y_preds = pl.concat(
[
y_pred,
y_preds.select(
[
entity_col,
pl.col(time_col).cast(schema[time_col]),
pl.col(target_col).cast(schema[target_col]),
]
),
]
)

y_preds = y_preds.lazy()
y_resids = y_resids.select(y_resids.columns[:3]).lazy()
y_pred_quantiles = enbpi(y_preds, y_resids, alphas)
return pl.concat(y_pred_qnts).sort([entity_col, time_col]).collect()

# Make alpha base 100
y_pred_quantiles = y_pred_quantiles.with_columns(
(pl.col("quantile") * 100).cast(pl.Int16)
)

return y_pred_quantiles
def _validate_alphas(alphas: Optional[Sequence[float]]) -> Sequence[float]:
if alphas is None:
return (0.1, 0.9)
elif len(alphas) != 2:
raise ValueError("alphas must be a list of length 2")
elif not all(0 < alpha < 1 for alpha in alphas):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also check that the sequence is sorted?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC the function returns the alphas sorted, so that should be OK. Am I correct?

raise ValueError("alphas must be between 0 and 1")
return alphas
25 changes: 25 additions & 0 deletions tests/test_conformal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

import pytest

from functime.conformal import _validate_alphas


def test_validate_alphas():
    """Check the default, accepted inputs and rejected inputs of ``_validate_alphas``."""
    # ``None`` falls back to the default pair of levels.
    assert _validate_alphas(None) == (0.1, 0.9)

    # Valid, ascending pairs are handed back with the same values.
    for valid in ([0.1, 0.9], [0.1, 0.5], [0.5, 0.9]):
        assert _validate_alphas(valid) == valid

    # Anything other than exactly two levels is rejected, including inputs
    # that are too short as well as too long.
    wrong_length = [
        [],
        [0.5],
        [0.1, 0.5, 0.9],
        [0.1, 0.5, 0.9, 0.2],
        [0.1, 0.5, 0.9, 0.2, 0.3],
    ]
    # Levels must lie strictly inside (0, 1); the bounds themselves are invalid.
    out_of_range = [
        [0.1, -0.5],
        [0.1, 1.5],
        [-0.1, 0.5],
        [1.1, 0.5],
        [0.0, 0.5],
        [0.5, 1.0],
    ]
    for invalid in wrong_length + out_of_range:
        with pytest.raises(ValueError):
            _validate_alphas(invalid)