Skip to content

Commit 1eefe9d

Browse files
committed
added rolling average on read
1 parent 70ade88 commit 1eefe9d

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

src/anemoi/datasets/data/dataset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,12 @@ def __subset(self, **kwargs: Any) -> "Dataset":
293293
if skip_missing_dates:
294294
return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate()
295295

296+
if "rolling_average" in kwargs:
297+
from .rolling_average import RollingAverage
298+
299+
rolling_average = kwargs.pop("rolling_average")
300+
return RollingAverage(self, rolling_average)._subset(**kwargs).mutate()
301+
296302
if "interpolate_frequency" in kwargs:
297303
from .interpolate import InterpolateFrequency
298304

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# (C) Copyright 2025 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
11+
import logging
12+
from functools import cached_property
13+
from typing import Any
14+
15+
import numpy as np
16+
from numpy.typing import NDArray
17+
18+
from anemoi.datasets.data.indexing import expand_list_indexing
19+
20+
from .dataset import Dataset
21+
from .dataset import FullIndex
22+
from .debug import Node
23+
from .debug import debug_indexing
24+
from .forwards import Forwards
25+
26+
LOG = logging.getLogger(__name__)
27+
28+
29+
class RollingAverage(Forwards):
30+
"""A class to represent a dataset with interpolated frequency."""
31+
32+
def __init__(self, dataset: Dataset, window: str | tuple[int, int, str]) -> None:
33+
"""Initialize the RollingAverage class.
34+
35+
Parameters
36+
----------
37+
dataset : Dataset
38+
The dataset to be averaged with a rolling window.
39+
window : (int, int, str)
40+
The rolling average window (start, end, 'freq').
41+
'freq' means the window is in number of time steps in the dataset.
42+
Both start and end are inclusive, i.e. window = (-2, 2, 'freq') means a window of 5 time steps.
43+
For now, only 'freq' is supported, in the future other units may be supported.
44+
Windows such as "[-2h, +2h]" are not supported yet.
45+
"""
46+
super().__init__(dataset)
47+
if not (isinstance(window, (list, tuple)) and len(window) == 3):
48+
raise ValueError(f"Window must be (int, int, str), got {window}")
49+
if not isinstance(window[0], int) or not isinstance(window[1], int) or not isinstance(window[2], str):
50+
raise ValueError(f"Window must be (int, int, str), got {window}")
51+
if window[2] not in ["freq", "frequency"]:
52+
raise NotImplementedError(f"Window must be (int, int, 'freq'), got {window}")
53+
54+
# window = (0, 0, 'freq') means no change
55+
self.i_start = -window[0]
56+
self.i_end = window[1] + 1
57+
if self.i_start <= 0:
58+
raise ValueError(f"Window start must be negative, got {window}")
59+
if self.i_end <= 0:
60+
raise ValueError(f"Window end must be positive, got {window}")
61+
62+
self.window_str = f"-{self.i_start}-to-{self.i_end}"
63+
64+
@property
65+
def shape(self):
66+
shape = list(self.forward.shape)
67+
shape[0] = len(self)
68+
return tuple(shape)
69+
70+
@debug_indexing
71+
@expand_list_indexing
72+
def __getitem__(self, n: FullIndex) -> NDArray[Any]:
73+
def f(array):
74+
return np.nanmean(array, axis=0)
75+
76+
if isinstance(n, slice):
77+
n = (n,)
78+
79+
if isinstance(n, tuple):
80+
first = n[0]
81+
if len(n) > 1:
82+
rest = n[1:]
83+
else:
84+
rest = ()
85+
86+
if isinstance(first, int):
87+
slice_ = slice(first, first + self.i_start + self.i_end)
88+
data = self.forward[(slice_,) + rest]
89+
return f(data)
90+
91+
if isinstance(first, slice):
92+
first = list(range(first.start or 0, first.stop or len(self), first.step or 1))
93+
94+
if isinstance(first, (list, tuple)):
95+
first = [i if i >= 0 else len(self) + i for i in first]
96+
if any(i >= len(self) for i in first):
97+
raise IndexError(f"Index out of range: {first}")
98+
slices = [slice(i, i + self.i_start + self.i_end) for i in first]
99+
data = [self.forward[(slice_,) + rest] for slice_ in slices]
100+
res = [f(d) for d in data]
101+
return np.array(res)
102+
103+
assert False, f"Expected int, slice, list or tuple as first element of tuple, got {type(first)}"
104+
105+
assert isinstance(n, int), f"Expected int, slice, tuple, got {type(n)}"
106+
107+
if n < 0:
108+
n = len(self) + n
109+
if n >= len(self):
110+
raise IndexError(f"Index out of range: {n}")
111+
112+
slice_ = slice(n, n + self.i_start + self.i_end)
113+
data = self.forward[slice_]
114+
return f(data)
115+
116+
def __len__(self) -> int:
117+
return len(self.forward) - (self.i_end + self.i_start - 1)
118+
119+
@cached_property
120+
def dates(self) -> NDArray[np.datetime64]:
121+
"""Get the interpolated dates."""
122+
dates = self.forward.dates
123+
return dates[self.i_start : len(dates) - self.i_end + 1]
124+
125+
def tree(self) -> Node:
126+
return Node(self, [self.forward.tree()], window=self.window_str)
127+
128+
@cached_property
129+
def missing(self) -> set[int]:
130+
"""Get the missing data indices."""
131+
result = []
132+
133+
for i in self.forward.missing:
134+
for j in range(0, self.i_end + self.i_start):
135+
result.append(i + j)
136+
137+
result = {x for x in result if x < self._len}
138+
return result
139+
140+
def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
141+
return {}

0 commit comments

Comments
 (0)