Skip to content

Commit df1845d

Browse files
committed
fixup! Add case_when API * Used to support conditional assignment operation.
1 parent cdec601 commit df1845d

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

pandas/core/case_when.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
1515
This is useful when you want to assign a column based on multiple conditions.
1616
Uses `Series.mask` to perform the assignment.
1717
18+
The returned Series will always have a new index (reset).
19+
1820
Parameters
1921
----------
2022
obj : Dataframe or Series on which the conditions will be applied.
@@ -28,8 +30,7 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
2830
will be used to create the `Series` on which `Series.mask` will be called.
2931
If this value is not already an array like (i.e. it is not of type `Series`,
3032
`np.array` or `list`) it will be repeated `obj.shape[0]` times in order to
31-
create an array like object from it and then apply the `Series.mask`. In any
32-
case, the default series will be forced to take the index of `obj`.
33+
create an array like object from it and then apply the `Series.mask`.
3334
3435
Returns
3536
-------
@@ -104,7 +105,7 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
104105
2 -1
105106
Name: a, dtype: int64
106107
107-
The index will always follow that of `obj`. For example:
108+
The index is not maintained. For example:
108109
>>> df = pd.DataFrame(
109110
... dict(a=[1, 2, 3], b=[4, 5, 6]),
110111
... index=['index 1', 'index 2', 'index 3']
@@ -121,9 +122,9 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
121122
... df.b,
122123
... default=0,
123124
... )
124-
index 1 4
125-
index 2 0
126-
index 3 0
125+
0 4
126+
1 0
127+
2 0
127128
dtype: int64
128129
"""
129130
len_args = len(args)
@@ -139,9 +140,9 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
139140

140141
# construct series on which we will apply `Series.mask`
141142
if is_array_like(default):
142-
series = pd.Series(default, index=obj.index)
143+
series = pd.Series(default).reset_index(drop=True)
143144
else:
144-
series = pd.Series([default] * obj.shape[0], index=obj.index)
145+
series = pd.Series([default] * obj.shape[0])
145146

146147
for i in range(0, len_args, 2):
147148
# get conditions
@@ -153,6 +154,13 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
153154
# get replacements
154155
replacements = args[i + 1]
155156

157+
# if `conditions` or `replacements` are series, make sure to reset their index
158+
if isinstance(conditions, pd.Series):
159+
conditions = conditions.reset_index(drop=True)
160+
161+
if isinstance(replacements, pd.Series):
162+
replacements = replacements.reset_index(drop=True)
163+
156164
# `Series.mask` call
157165
series = series.mask(conditions, replacements)
158166

0 commit comments

Comments
 (0)