Skip to content

Commit cdec601

Browse files
committed
fixup! Add case_when API * Used to support conditional assignment operation.
1 parent 3e6dc99 commit cdec601

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

pandas/core/case_when.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
2828
will be used to create the `Series` on which `Series.mask` will be called.
2929
If this value is not already an array like (i.e. it is not of type `Series`,
3030
`np.array` or `list`) it will be repeated `obj.shape[0]` times in order to
31-
create an array like object from it and then apply the `Series.mask`.
31+
create an array like object from it and then apply the `Series.mask`. In any
32+
case, the default series will be forced to take the index of `obj`.
3233
3334
Returns
3435
-------
@@ -102,6 +103,28 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
102103
1 -1
103104
2 -1
104105
Name: a, dtype: int64
106+
107+
The index will always follow that of `obj`. For example:
108+
>>> df = pd.DataFrame(
109+
... dict(a=[1, 2, 3], b=[4, 5, 6]),
110+
... index=['index 1', 'index 2', 'index 3']
111+
... )
112+
>>> df
113+
a b
114+
index 1 1 4
115+
index 2 2 5
116+
index 3 3 6
117+
118+
>>> pd.case_when(
119+
... df,
120+
... lambda x: (x.a == 1) & (x.b == 4),
121+
... df.b,
122+
... default=0,
123+
... )
124+
index 1 4
125+
index 2 0
126+
index 3 0
127+
dtype: int64
105128
"""
106129
len_args = len(args)
107130

@@ -116,9 +139,9 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
116139

117140
# construct series on which we will apply `Series.mask`
118141
if is_array_like(default):
119-
series = pd.Series(default)
142+
series = pd.Series(default, index=obj.index)
120143
else:
121-
series = pd.Series([default] * obj.shape[0])
144+
series = pd.Series([default] * obj.shape[0], index=obj.index)
122145

123146
for i in range(0, len_args, 2):
124147
# get conditions

0 commit comments

Comments
 (0)