Skip to content

drop_incomplete support in SeasonGrouper #10436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 108 additions & 19 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
"EncodedGroups",
"Grouper",
"Resampler",
"SeasonGrouper",
"SeasonResampler",
"TimeResampler",
"UniqueGrouper",
]
Expand Down Expand Up @@ -237,7 +239,7 @@ def _factorize_given_labels(self, group: T_Group) -> EncodedGroups:
)
return EncodedGroups(
codes=codes,
full_index=pd.Index(self.labels), # type: ignore[arg-type]
full_index=pd.Index(self.labels),
unique_coord=Variable(
dims=codes.name,
data=self.labels,
Expand Down Expand Up @@ -611,6 +613,47 @@ def unique_value_groups(
return values, inverse


def _adjust_years_for_season(
years: np.ndarray,
months: np.ndarray,
season_tuple: tuple[int, ...],
season_str: str,
) -> np.ndarray:
"""
Adjust years for seasons that span December and January (e.g., DJF).

For seasons like DJF, January and February should be considered part of the
winter that started in the previous December.

Parameters
----------
years : np.ndarray
Array of years corresponding to each timestamp
months : np.ndarray
Array of months corresponding to each timestamp
season_tuple : tuple of int
Tuple of month numbers that make up the season
season_str : str
String representation of the season (e.g., "DJF")

Returns
-------
np.ndarray
Adjusted years array
"""
year_adjusted = years.copy()
# Handle seasons that contain December followed by other months
if "D" in season_str and 12 in season_tuple:
# Find the position of "D" (December) in the season string
d_index = season_str.index("D")
# Get all months that come after December in the season
months_after_dec = season_tuple[d_index + 1 :]
# Reduce year by 1 for months that come after December
for month_num in months_after_dec:
year_adjusted[months == month_num] -= 1
return year_adjusted


def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]:
"""
>>> season_to_month_tuple(["DJF", "MAM", "JJA", "SON"])
Expand Down Expand Up @@ -741,25 +784,30 @@ class SeasonGrouper(Grouper):
seasons: sequence of str
List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc.
Overlapping seasons are allowed (e.g. ``["DJFM", "MAMJ", "JJAS", "SOND"]``)
drop_incomplete: bool, default: False
Whether to drop seasons that are not completely included in the data.
For example, if a time series starts in Jan-2001, and seasons includes `"DJF"`
then observations from Jan-2001, and Feb-2001 are ignored in the grouping
since Dec-2000 isn't present. This check is performed for each year.

Examples
--------
>>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"])
SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND'])
SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND'], drop_incomplete=False)

The ordering is preserved

>>> SeasonGrouper(["MAM", "JJAS", "OND", "JF"])
SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF'])
SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF'], drop_incomplete=False)

Overlapping seasons are allowed

>>> SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"])
SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND'])
SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND'], drop_incomplete=False)
"""

seasons: Sequence[str]
# drop_incomplete: bool = field(default=True) # TODO
drop_incomplete: bool = field(default=False, kw_only=True)

def factorize(self, group: T_Group) -> EncodedGroups:
if TYPE_CHECKING:
Expand All @@ -771,15 +819,43 @@ def factorize(self, group: T_Group) -> EncodedGroups:
months = group.dt.month.data
seasons_groups = find_independent_seasons(self.seasons)
codes_ = np.full((len(seasons_groups),) + group.shape, -1, dtype=np.int8)
group_indices: list[list[int]] = [[]] * len(self.seasons)
group_indices: list[list[int]] = [[] for _ in range(len(self.seasons))]

if self.drop_incomplete:
year = group.dt.year.data

for axis_index, seasgroup in enumerate(seasons_groups):
for season_tuple, code in zip(
seasgroup.inds, seasgroup.codes, strict=False
):
mask = np.isin(months, season_tuple)
codes_[axis_index, mask] = code
(indices,) = mask.nonzero()
group_indices[code] = indices.tolist()
if not self.drop_incomplete:
codes_[axis_index, mask] = code
(indices,) = mask.nonzero()
group_indices[code] = indices.tolist()
else:
season_str = self.seasons[code]
year_adjusted = _adjust_years_for_season(
year, months, season_tuple, season_str
)

# find unique years for this season
if not np.any(mask):
continue
unique_years = np.unique(year_adjusted[mask])

for yr in unique_years:
year_mask = year_adjusted == yr

# elements for this season in this year
year_season_mask = mask & year_mask

# check for completeness
present_months = np.unique(months[year_season_mask])
if len(present_months) == len(season_tuple):
codes_[axis_index, year_season_mask] = code
(indices,) = year_season_mask.nonzero()
group_indices[code].extend(indices.tolist())

if np.all(codes_ == -1):
raise ValueError(
Expand All @@ -792,8 +868,16 @@ def factorize(self, group: T_Group) -> EncodedGroups:
attrs=group.attrs,
name="season",
)
unique_coord = Variable("season", self.seasons, attrs=group.attrs)
full_index = pd.Index(self.seasons)

# Always filter coordinates to match actual data present
# This avoids dimension mismatches regardless of drop_incomplete setting
present_codes = np.unique(codes.data.ravel())
present_codes = present_codes[present_codes >= 0] # Remove -1 (missing data)
present_seasons = [self.seasons[code] for code in present_codes]

unique_coord = Variable("season", present_seasons, attrs=group.attrs)
full_index = pd.Index(present_seasons)

return EncodedGroups(
codes=codes,
group_indices=tuple(group_indices),
Expand All @@ -802,7 +886,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
)

def reset(self) -> Self:
return type(self)(self.seasons)
return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete)


@dataclass
Expand Down Expand Up @@ -872,10 +956,15 @@ def factorize(self, group: T_Group) -> EncodedGroups:
# offset years for seasons with December and January
for season_str, season_ind in zip(seasons, season_inds, strict=True):
season_label[month.isin(season_ind)] = season_str
if "DJ" in season_str:
after_dec = season_ind[season_str.index("D") + 1 :]
# important: this is assuming non-overlapping seasons
year[month.isin(after_dec)] -= 1

# Apply year adjustment for cross-year seasons
year_adjusted = year.copy()
for season_str, season_ind in zip(seasons, season_inds, strict=True):
if "D" in season_str and 12 in season_ind:
# Use helper function for year adjustment
year_adjusted[:] = _adjust_years_for_season(
year_adjusted.values, month.values, tuple(season_ind), season_str
)

# Allow users to skip one or more months?
# present_seasons is a mask that is True for months that are requested in the output
Expand All @@ -889,7 +978,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
"month": month[present_seasons],
},
index=pd.MultiIndex.from_arrays(
[year.data[present_seasons], season_label[present_seasons]],
[year_adjusted.data[present_seasons], season_label[present_seasons]],
names=["year", "season"],
),
)
Expand Down Expand Up @@ -928,7 +1017,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
[
datetime_class(year=y, month=m, day=1)
for y, m in itertools.product(
range(year[0].item(), year[-1].item() + 1),
range(year_adjusted[0].item(), year_adjusted[-1].item() + 1),
[s[0] for s in season_inds],
)
]
Expand All @@ -943,7 +1032,7 @@ def get_label(year, season):
unique_codes = np.arange(len(unique_coord))
valid_season_mask = season_label != ""
first_valid_season, last_valid_season = season_label[valid_season_mask][[0, -1]]
first_year, last_year = year.data[[0, -1]]
first_year, last_year = year_adjusted.data[[0, -1]]
if self.drop_incomplete:
if month.data[valid_season_mask][0] != season_tuples[first_valid_season][0]:
if "DJ" in first_valid_season:
Expand Down
138 changes: 138 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3604,6 +3604,144 @@ def test_season_resampler_groupby_identical(self):
gb = da.groupby(time=resampler).sum()
assert_identical(rs, gb)

@pytest.mark.parametrize("calendar", ["standard"])
def test_season_grouper_drop_incomplete_default_false(self, calendar):
"""Test that drop_incomplete=False is the default and includes partial seasons."""
# Create data that starts mid-winter (missing Dec 2000)
time = date_range("2001-01-01", "2001-12-31", freq="MS", calendar=calendar)
data = np.arange(len(time))
da = DataArray(data, dims="time", coords={"time": time})

# Default behavior should include incomplete seasons
result_default = da.groupby(
time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"])
).mean()
result_explicit = da.groupby(
time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
).mean()

assert_identical(result_default, result_explicit)

# Should include 4 seasons (including incomplete DJF with just Jan-Feb)
assert len(result_default) == 4
assert list(result_default.season.values) == ["DJF", "MAM", "JJA", "SON"]

@pytest.mark.parametrize("calendar", ["standard"])
def test_season_grouper_drop_incomplete_true(self, calendar):
"""Test that drop_incomplete=True excludes partial seasons."""
# Create data that starts mid-winter (missing Dec 2000) and ends mid-autumn (missing Nov-Dec 2002)
time = date_range("2001-01-01", "2002-10-31", freq="MS", calendar=calendar)
data = np.arange(len(time))
da = DataArray(data, dims="time", coords={"time": time})

# With drop_incomplete=True, should exclude incomplete seasons
result_drop = da.groupby(
time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
).mean()

# Should only include complete seasons
# 2001: DJF is incomplete (missing Dec 2000), MAM/JJA/SON are complete
# 2002: DJF/MAM/JJA are complete, SON is incomplete (missing Nov-Dec)
assert len(result_drop) <= 6 # At most 6 complete seasons

# Compare with default behavior
result_default = da.groupby(
time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
).mean()
assert len(result_drop) <= len(result_default)

@pytest.mark.parametrize("calendar", ["standard"])
def test_season_grouper_drop_incomplete_cross_year_seasons(self, calendar):
"""Test drop_incomplete with seasons that span calendar years like DJF."""
# Create 2 complete years of data
time = date_range("2001-01-01", "2002-12-31", freq="MS", calendar=calendar)
data = np.arange(len(time))
da = DataArray(data, dims="time", coords={"time": time})

# Test with DJF season (spans calendar year)
result_keep = da.groupby(
time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
).mean()
result_drop = da.groupby(
time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
).mean()

# With complete data, both should give same number of seasons
assert len(result_keep) == len(result_drop)

# Now test with incomplete data - start from Feb (missing Dec 2000 and Jan 2001)
time_incomplete = date_range(
"2001-02-01",
"2001-12-31",
freq="MS",
calendar=calendar, # Stop before next year
)
data_incomplete = np.arange(len(time_incomplete))
da_incomplete = DataArray(
data_incomplete, dims="time", coords={"time": time_incomplete}
)

result_keep_inc = da_incomplete.groupby(
time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
).mean()
result_drop_inc = da_incomplete.groupby(
time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
).mean()

# drop_incomplete should exclude the incomplete first DJF season
# Data starts in Feb 2001, so 2000-2001 DJF is incomplete (missing Dec 2000, Jan 2001)
assert len(result_drop_inc) < len(result_keep_inc)

@pytest.mark.parametrize("calendar", ["standard"])
def test_season_grouper_drop_incomplete_all_incomplete(self, calendar):
"""Test that drop_incomplete handles the case where all seasons are incomplete."""
# Create data with only January (incomplete for any multi-month season)
time = date_range("2001-01-01", "2001-01-31", freq="D", calendar=calendar)
data = np.arange(len(time))
da = DataArray(data, dims="time", coords={"time": time})

# Should raise error when all seasons are incomplete and drop_incomplete=True
with pytest.raises(ValueError, match="Failed to group data"):
da.groupby(
time=SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
).mean()

def test_season_grouper_reset_preserves_drop_incomplete(self):
"""Test that the reset method preserves the drop_incomplete setting."""
grouper1 = SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=True)
grouper2 = grouper1.reset()

assert grouper2.drop_incomplete == grouper1.drop_incomplete
assert grouper2.seasons == grouper1.seasons

grouper3 = SeasonGrouper(["DJF", "MAM", "JJA", "SON"], drop_incomplete=False)
grouper4 = grouper3.reset()

assert grouper4.drop_incomplete == grouper3.drop_incomplete
assert grouper4.seasons == grouper3.seasons

def test_adjust_years_for_season_helper(self):
"""Test the helper function _adjust_years_for_season."""
from xarray.groupers import _adjust_years_for_season

years = np.array([2001, 2001, 2001, 2002, 2002, 2002])
months = np.array([12, 1, 2, 12, 1, 2])

# Test DJF season (December, January, February)
adjusted = _adjust_years_for_season(years, months, (12, 1, 2), "DJF")
expected = np.array(
[2001, 2000, 2000, 2002, 2001, 2001]
) # Jan/Feb get previous year
np.testing.assert_array_equal(adjusted, expected)

# Test MAM season (no cross-year adjustment needed)
adjusted_mam = _adjust_years_for_season(years, months, (3, 4, 5), "MAM")
np.testing.assert_array_equal(adjusted_mam, years) # Should be unchanged

# Test single month season
adjusted_jan = _adjust_years_for_season(years, months, (1,), "J")
np.testing.assert_array_equal(adjusted_jan, years) # Should be unchanged


# TODO: Possible property tests to add to this module
# 1. lambda x: x
Expand Down
Loading