Skip to content

Commit 299927a

Browse files
mostly types, also exclude groupby
1 parent b7ea297 commit 299927a

File tree

6 files changed

+32
-11
lines changed

6 files changed

+32
-11
lines changed

pandas/core/dtypes/dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,7 @@ class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype):
951951
# "Dict[int, PandasExtensionDtype]", base class "PandasExtensionDtype"
952952
# defined the type as "Dict[str, PandasExtensionDtype]") [assignment]
953953
_cache_dtypes: dict[BaseOffset, PeriodDtype] = {} # type: ignore[assignment] # noqa: E501
954-
__hash__ = PeriodDtypeBase.__hash__
954+
__hash__ = PeriodDtypeBase.__hash__ # type: ignore[assignment]
955955
_freq: BaseOffset
956956

957957
def __new__(cls, freq):

pandas/core/frame.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5510,6 +5510,8 @@ def shift(
55105510
axis = self._get_axis_number(axis)
55115511

55125512
if is_list_like(periods):
5513+
# periods is not necessarily a list, but otherwise mypy complains.
5514+
periods = cast(list, periods)
55135515
if axis == 1:
55145516
raise ValueError(
55155517
"If `periods` contains multiple shifts, `axis` cannot be 1."
@@ -5518,18 +5520,20 @@ def shift(
55185520
raise ValueError("If `periods` is an iterable, it cannot be empty.")
55195521
from pandas.core.reshape.concat import concat
55205522

5521-
result = []
5523+
shifted_dataframes = []
55225524
for period in periods:
5523-
if not isinstance(period, int):
5525+
if not isinstance(int, period):
55245526
raise TypeError(
55255527
f"Periods must be integer, but {period} is {type(period)}."
55265528
)
5527-
result.append(
5529+
period = cast(int, period)
5530+
shifted_dataframes.append(
55285531
super()
55295532
.shift(periods=period, freq=freq, axis=axis, fill_value=fill_value)
55305533
.add_suffix(f"{suffix}_{period}" if suffix else f"_{period}")
55315534
)
5532-
return concat(result, axis=1) if result else self
5535+
return concat(shifted_dataframes, axis=1) if shifted_dataframes else self
5536+
periods = cast(int, periods)
55335537

55345538
if freq is not None and fill_value is not lib.no_default:
55355539
# GH#53832

pandas/core/generic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10522,7 +10522,7 @@ def mask(
1052210522
@doc(klass=_shared_doc_kwargs["klass"])
1052310523
def shift(
1052410524
self,
10525-
periods: int | Iterable = 1,
10525+
periods: int | Iterable[int] = 1,
1052610526
freq=None,
1052710527
axis: Axis = 0,
1052810528
fill_value: Hashable = lib.no_default,
@@ -10656,6 +10656,7 @@ def shift(
1065610656
return self.to_frame().shift(
1065710657
periods=periods, freq=freq, axis=axis, fill_value=fill_value
1065810658
)
10659+
periods = cast(int, periods)
1065910660

1066010661
if freq is None:
1066110662
# when freq is None, data is shifted, index is not

pandas/core/groupby/groupby.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4866,7 +4866,6 @@ def shift(
48664866
freq=None,
48674867
axis: Axis | lib.NoDefault = lib.no_default,
48684868
fill_value=None,
4869-
suffix: str | None = None,
48704869
):
48714870
"""
48724871
Shift each group by periods observations.
@@ -4875,8 +4874,9 @@ def shift(
48754874
48764875
Parameters
48774876
----------
4878-
periods : int, default 1
4879-
Number of periods to shift.
4877+
periods : int | Iterable[int], default 1
4878+
Number of periods to shift. If a list of values, shift each group by
4879+
each period.
48804880
freq : str, optional
48814881
Frequency string.
48824882
axis : axis to shift, default 0
@@ -4888,8 +4888,6 @@ def shift(
48884888
48894889
fill_value : optional
48904890
The scalar value to use for newly introduced missing values.
4891-
suffix : str, optional
4892-
An optional suffix to append when there are multiple periods.
48934891
48944892
Returns
48954893
-------
@@ -4938,6 +4936,11 @@ def shift(
49384936
catfish NaN NaN
49394937
goldfish 5.0 8.0
49404938
"""
4939+
if is_list_like(periods):
4940+
raise NotImplementedError(
4941+
"shift with multiple periods is not implemented yet for groupby."
4942+
)
4943+
49414944
if axis is not lib.no_default:
49424945
axis = self.obj._get_axis_number(axis)
49434946
self._deprecate_axis(axis, "shift")

pandas/tests/groupby/test_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def test_frame_consistency(groupby_func):
192192
exclude_expected = {"numeric_only"}
193193
elif groupby_func in ("quantile",):
194194
exclude_expected = {"method", "axis"}
195+
elif groupby_func in ("shift",):
196+
exclude_expected = {"suffix"}
195197

196198
# Ensure excluded arguments are actually in the signatures
197199
assert result & exclude_result == exclude_result
@@ -252,6 +254,8 @@ def test_series_consistency(request, groupby_func):
252254
exclude_expected = {"args", "kwargs"}
253255
elif groupby_func in ("quantile",):
254256
exclude_result = {"numeric_only"}
257+
elif groupby_func in ("shift",):
258+
exclude_expected = {"suffix"}
255259

256260
# Ensure excluded arguments are actually in the signatures
257261
assert result & exclude_result == exclude_result

pandas/tests/groupby/test_groupby_shift_diff.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,12 @@ def test_multindex_empty_shift_with_fill():
154154
shifted_with_fill = df.groupby(["a", "b"]).shift(1, fill_value=0)
155155
tm.assert_frame_equal(shifted, shifted_with_fill)
156156
tm.assert_index_equal(shifted.index, shifted_with_fill.index)
157+
158+
159+
def test_group_shift_with_multiple_periods():
160+
df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]})
161+
with pytest.raises(
162+
NotImplementedError,
163+
match=r"shift with multiple periods is not implemented yet for groupby.",
164+
):
165+
df.groupby("a")["b"].shift([1, 2])

0 commit comments

Comments
 (0)