Skip to content

Commit 63495c3

Browse files
authored
fix: unpack fcsts from batches for chronos 2 (#253)
2 parents e2cc558 + 89a6f6c commit 63495c3

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

tests/models/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def disable_mps_session(monkeypatch):
6262
ADIDA(),
6363
Prophet(),
6464
Chronos(repo_id="amazon/chronos-bolt-tiny", alias="Chronos-Bolt"),
65-
Chronos(repo_id="s3://autogluon/chronos-2", alias="Chronos-2"),
65+
Chronos(repo_id="amazon/chronos-2", alias="Chronos-2"),
66+
Chronos(repo_id="amazon/chronos-2", alias="Chronos-2", batch_size=2),
6667
FlowState(repo_id="ibm-research/flowstate"),
6768
FlowState(
6869
repo_id="ibm-granite/granite-timeseries-flowstate-r1",

timecopilot/models/foundation/chronos.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
5252
| Model ID | Parameters |
5353
| ---------------------------------------------------------------------- | ---------- |
54-
| [`s3://autogluon/chronos-2`](https://arxiv.org/abs/2510.15821) | 120M |
54+
| [`amazon/chronos-2`](https://huggingface.co/amazon/chronos-2) | 120M |
5555
| [`amazon/chronos-bolt-tiny`](https://huggingface.co/amazon/chronos-bolt-tiny) | 9M |
5656
| [`amazon/chronos-bolt-mini`](https://huggingface.co/amazon/chronos-bolt-mini) | 21M |
5757
| [`amazon/chronos-bolt-small`](https://huggingface.co/amazon/chronos-bolt-small) | 48M |
@@ -118,8 +118,12 @@ def _predict(
118118
] # list of tuples
119119
fcsts_quantiles, fcsts_mean = zip(*fcsts, strict=False)
120120
if isinstance(model, Chronos2Pipeline):
121-
fcsts_mean = fcsts_mean[0]
122-
fcsts_quantiles = fcsts_quantiles[0]
121+
fcsts_mean = [f_mean for fcst in fcsts_mean for f_mean in fcst] # type: ignore
122+
fcsts_quantiles = [
123+
f_quantile
124+
for fcst in fcsts_quantiles
125+
for f_quantile in fcst # type: ignore
126+
]
123127
fcsts_mean_np = torch.cat(fcsts_mean).numpy()
124128
fcsts_quantiles_np = torch.cat(fcsts_quantiles).numpy()
125129
else:
@@ -131,7 +135,7 @@ def _predict(
131135
for batch in tqdm(dataset)
132136
]
133137
if isinstance(model, Chronos2Pipeline):
134-
fcsts = fcsts[0]
138+
fcsts = [f_fcst for fcst in fcsts for f_fcst in fcst] # type: ignore
135139
fcsts = torch.cat(fcsts)
136140
if isinstance(model, ChronosPipeline):
137141
# for t5 models, `predict` returns a tensor of shape

0 commit comments

Comments
 (0)