Skip to content

Commit 8d085f3

Browse files
address comments
1 parent 6f3ec9b commit 8d085f3

File tree

4 files changed

+79
-42
lines changed

4 files changed

+79
-42
lines changed

pandas/core/frame.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5507,6 +5507,13 @@ def shift(
55075507
fill_value: Hashable = lib.no_default,
55085508
suffix: str | None = None,
55095509
) -> DataFrame:
5510+
if freq is not None and fill_value is not lib.no_default:
5511+
# GH#53832
5512+
raise ValueError(
5513+
"Cannot pass both 'freq' and 'fill_value' to "
5514+
f"{type(self).__name__}.shift"
5515+
)
5516+
55105517
axis = self._get_axis_number(axis)
55115518

55125519
if is_list_like(periods):
@@ -5534,13 +5541,6 @@ def shift(
55345541
return concat(shifted_dataframes, axis=1) if shifted_dataframes else self
55355542
periods = cast(int, periods)
55365543

5537-
if freq is not None and fill_value is not lib.no_default:
5538-
# GH#53832
5539-
raise ValueError(
5540-
"Cannot pass both 'freq' and 'fill_value' to "
5541-
f"{type(self).__name__}.shift"
5542-
)
5543-
55445544
ncols = len(self.columns)
55455545
arrays = self._mgr.arrays
55465546
if axis == 1 and periods != 0 and ncols > 0 and freq is None:

pandas/core/groupby/groupby.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4949,8 +4949,6 @@ def shift(
49494949
axis = 0
49504950

49514951
if is_list_like(periods):
4952-
# periods is not necessarily a list, but otherwise mypy complains.
4953-
periods = cast(list, periods)
49544952
if axis == 1:
49554953
raise ValueError(
49564954
"If `periods` contains multiple shifts, `axis` cannot be 1."
@@ -4959,39 +4957,50 @@ def shift(
49594957
raise ValueError("If `periods` is an iterable, it cannot be empty.")
49604958
from pandas.core.reshape.concat import concat
49614959

4962-
shifted_dataframes = []
4963-
for period in periods:
4964-
if not isinstance(period, int):
4965-
raise TypeError(
4966-
f"Periods must be integer, but {period} is {type(period)}."
4967-
)
4968-
shifted_dataframes.append(
4969-
DataFrame(
4970-
self.shift(
4971-
periods=period, freq=freq, axis=axis, fill_value=fill_value
4972-
)
4973-
).add_suffix(f"{suffix}_{period}" if suffix else f"_{period}")
4960+
add_suffix = True
4961+
else:
4962+
periods = [periods]
4963+
add_suffix = False
4964+
4965+
shifted_dataframes = []
4966+
for period in periods:
4967+
if not isinstance(period, int):
4968+
raise TypeError(
4969+
f"Periods must be integer, but {period} is {type(period)}."
4970+
)
4971+
if freq is not None or axis != 0:
4972+
f = lambda x: x.shift(period, freq, axis, fill_value)
4973+
shifted = self._python_apply_general(
4974+
f, self._selected_obj, is_transform=True
49744975
)
4975-
return concat(shifted_dataframes, axis=1) if shifted_dataframes else self
4976-
periods = cast(int, periods)
49774976

4978-
if freq is not None or axis != 0:
4979-
f = lambda x: x.shift(periods, freq, axis, fill_value)
4980-
return self._python_apply_general(f, self._selected_obj, is_transform=True)
4977+
else:
4978+
ids, _, ngroups = self.grouper.group_info
4979+
res_indexer = np.zeros(len(ids), dtype=np.int64)
49814980

4982-
ids, _, ngroups = self.grouper.group_info
4983-
res_indexer = np.zeros(len(ids), dtype=np.int64)
4981+
libgroupby.group_shift_indexer(res_indexer, ids, ngroups, period)
49844982

4985-
libgroupby.group_shift_indexer(res_indexer, ids, ngroups, periods)
4983+
obj = self._obj_with_exclusions
49864984

4987-
obj = self._obj_with_exclusions
4985+
shifted = obj._reindex_with_indexers(
4986+
{self.axis: (obj.axes[self.axis], res_indexer)},
4987+
fill_value=fill_value,
4988+
allow_dups=True,
4989+
)
49884990

4989-
res = obj._reindex_with_indexers(
4990-
{self.axis: (obj.axes[self.axis], res_indexer)},
4991-
fill_value=fill_value,
4992-
allow_dups=True,
4991+
if add_suffix:
4992+
if len(shifted.shape) == 1:
4993+
shifted = shifted.to_frame()
4994+
shifted = shifted.add_suffix(
4995+
f"{suffix}_{period}" if suffix else f"_{period}"
4996+
)
4997+
shifted_dataframes.append(shifted)
4998+
4999+
return (
5000+
shifted_dataframes[0]
5001+
if len(shifted_dataframes) == 1
5002+
else concat(shifted_dataframes, axis=1)
49935003
)
4994-
return res
49955004

49965005
@final
49975006
@Substitution(name="groupby")

pandas/tests/frame/methods/test_shift.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def test_shift_axis1_many_periods(self):
657657
shifted2 = df.shift(-6, axis=1, fill_value=None)
658658
tm.assert_frame_equal(shifted2, expected)
659659

660-
def test_shift_with_iterable(self):
660+
def test_shift_with_iterable_basic_functionality(self):
661661
# GH#44424
662662
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
663663
shifts = [0, 1, 2]
@@ -677,19 +677,47 @@ def test_shift_with_iterable(self):
677677
)
678678
tm.assert_frame_equal(expected, shifted)
679679

680-
# test pd.Series
680+
def test_shift_with_iterable_series(self):
681+
data = {"a": [1, 2, 3]}
682+
shifts = [0, 1, 2]
683+
684+
df = DataFrame(data)
681685
s: Series = df["a"]
682-
df_one_column: DataFrame = df[["a"]]
683-
tm.assert_frame_equal(s.shift(shifts), df_one_column.shift(shifts))
686+
tm.assert_frame_equal(s.shift(shifts), df.shift(shifts))
687+
688+
def test_shift_with_iterable_freq_and_fill_value(self):
689+
df = DataFrame(
690+
np.random.randn(5), index=date_range("1/1/2000", periods=5, freq="H")
691+
)
692+
693+
tm.assert_frame_equal(
694+
# rename because shift with an iterable leads to str column names
695+
df.shift([1], fill_value=1).rename(columns=lambda x: int(x[0])),
696+
df.shift(1, fill_value=1),
697+
)
698+
699+
tm.assert_frame_equal(
700+
df.shift([1], freq="H").rename(columns=lambda x: int(x[0])),
701+
df.shift(1, freq="H"),
702+
)
703+
704+
msg = r"Cannot pass both 'freq' and 'fill_value' to.*"
705+
with pytest.raises(ValueError, match=msg):
706+
df.shift([1, 2], fill_value=1, freq="H")
707+
708+
def test_shift_with_iterable_check_other_arguments(self):
709+
data = {"a": [1, 2], "b": [4, 5]}
710+
shifts = [0, 1]
711+
df = DataFrame(data)
684712

685713
# test suffix
686714
columns = df[["a"]].shift(shifts, suffix="_suffix").columns
687-
assert columns.tolist() == ["a_suffix_0", "a_suffix_1", "a_suffix_2"]
715+
assert columns.tolist() == ["a_suffix_0", "a_suffix_1"]
688716

689717
# check bad inputs when doing multiple shifts
690718
msg = "If `periods` contains multiple shifts, `axis` cannot be 1."
691719
with pytest.raises(ValueError, match=msg):
692-
df.shift([1, 2], axis=1)
720+
df.shift(shifts, axis=1)
693721

694722
msg = "Periods must be integer, but s is <class 'str'>."
695723
with pytest.raises(TypeError, match=msg):

pandas/tests/groupby/test_groupby_shift_diff.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_multindex_empty_shift_with_fill():
156156
tm.assert_index_equal(shifted.index, shifted_with_fill.index)
157157

158158

159-
@pytest.mark.filterwarnings("ignore::FutureWarning")
159+
@pytest.mark.filterwarnings("ignore:The 'axis' keyword in")
160160
def test_group_shift_with_multiple_periods():
161161
df = DataFrame({"a": [1, 2, 3, 3, 2], "b": [True, True, False, False, True]})
162162

0 commit comments

Comments
 (0)