Skip to content

Commit cabd28c

Browse files
change how futurewarning is handled in the test
1 parent b78d7d6 commit cabd28c

File tree

4 files changed

+39
-14
lines changed

4 files changed

+39
-14
lines changed

pandas/core/frame.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5526,6 +5526,7 @@ def shift(
55265526
raise TypeError(
55275527
f"Periods must be integer, but {period} is {type(period)}."
55285528
)
5529+
print(super())
55295530
shifted_dataframes.append(
55305531
super()
55315532
.shift(periods=period, freq=freq, axis=axis, fill_value=fill_value)

pandas/core/groupby/groupby.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4866,6 +4866,7 @@ def shift(
48664866
freq=None,
48674867
axis: Axis | lib.NoDefault = lib.no_default,
48684868
fill_value=None,
4869+
suffix: str | None = None,
48694870
):
48704871
"""
48714872
Shift each group by periods observations.
@@ -4936,17 +4937,38 @@ def shift(
49364937
catfish NaN NaN
49374938
goldfish 5.0 8.0
49384939
"""
4939-
if is_list_like(periods):
4940-
raise NotImplementedError(
4941-
"shift with multiple periods is not implemented yet for groupby."
4942-
)
4943-
49444940
if axis is not lib.no_default:
49454941
axis = self.obj._get_axis_number(axis)
49464942
self._deprecate_axis(axis, "shift")
49474943
else:
49484944
axis = 0
49494945

4946+
if is_list_like(periods):
4947+
# periods is not necessarily a list, but otherwise mypy complains.
4948+
periods = cast(list, periods)
4949+
if axis == 1:
4950+
raise ValueError(
4951+
"If `periods` contains multiple shifts, `axis` cannot be 1."
4952+
)
4953+
if len(periods) == 0:
4954+
raise ValueError("If `periods` is an iterable, it cannot be empty.")
4955+
from pandas.core.reshape.concat import concat
4956+
4957+
shifted_dataframes = []
4958+
for period in periods:
4959+
if not isinstance(period, int):
4960+
raise TypeError(
4961+
f"Periods must be integer, but {period} is {type(period)}."
4962+
)
4963+
shifted_dataframes.append(
4964+
DataFrame(
4965+
self.shift(
4966+
periods=period, freq=freq, axis=axis, fill_value=fill_value
4967+
)
4968+
).add_suffix(f"{suffix}_{period}" if suffix else f"_{period}")
4969+
)
4970+
return concat(shifted_dataframes, axis=1) if shifted_dataframes else self
4971+
49504972
if freq is not None or axis != 0:
49514973
f = lambda x: x.shift(periods, freq, axis, fill_value)
49524974
return self._python_apply_general(f, self._selected_obj, is_transform=True)

pandas/tests/groupby/test_api.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,6 @@ def test_frame_consistency(groupby_func):
192192
exclude_expected = {"numeric_only"}
193193
elif groupby_func in ("quantile",):
194194
exclude_expected = {"method", "axis"}
195-
elif groupby_func in ("shift",):
196-
exclude_expected = {"suffix"}
197195

198196
# Ensure excluded arguments are actually in the signatures
199197
assert result & exclude_result == exclude_result
@@ -254,8 +252,6 @@ def test_series_consistency(request, groupby_func):
254252
exclude_expected = {"args", "kwargs"}
255253
elif groupby_func in ("quantile",):
256254
exclude_result = {"numeric_only"}
257-
elif groupby_func in ("shift",):
258-
exclude_expected = {"suffix"}
259255

260256
# Ensure excluded arguments are actually in the signatures
261257
assert result & exclude_result == exclude_result

pandas/tests/groupby/test_groupby_shift_diff.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,16 @@ def test_multindex_empty_shift_with_fill():
156156
tm.assert_index_equal(shifted.index, shifted_with_fill.index)
157157

158158

159+
@pytest.mark.filterwarnings("ignore::FutureWarning")
159160
def test_group_shift_with_multiple_periods():
160161
df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]})
161-
with pytest.raises(
162-
NotImplementedError,
163-
match=r"shift with multiple periods is not implemented yet for groupby.",
164-
):
165-
df.groupby("a")["b"].shift([1, 2])
162+
163+
shifted_df = df.groupby("b")[["a"]].shift([0, 1])
164+
expected_df = DataFrame(
165+
{"a_0": [1, 2, 3, 3, 2], "a_1": [np.nan, 1.0, np.nan, 3.0, 2.0]}
166+
)
167+
tm.assert_frame_equal(shifted_df, expected_df)
168+
169+
# series
170+
shifted_series = df.groupby("b")["a"].shift([0, 1])
171+
tm.assert_frame_equal(shifted_series, expected_df)

0 commit comments

Comments
 (0)