@@ -15,6 +15,8 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
15
15
This is useful when you want to assign a column based on multiple conditions.
16
16
Uses `Series.mask` to perform the assignment.
17
17
18
+ The returned Series always has a freshly reset index; the index of `obj` is not preserved.
19
+
18
20
Parameters
19
21
----------
20
22
obj : DataFrame or Series on which the conditions will be applied.
@@ -28,8 +30,7 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
28
30
will be used to create the `Series` on which `Series.mask` will be called.
29
31
If this value is not already array-like (i.e. it is not of type `Series`,
30
32
`np.array` or `list`) it will be repeated `obj.shape[0]` times in order to
31
- create an array like object from it and then apply the `Series.mask`. In any
32
- case, the default series will be forced to take the index of `obj`.
33
+ create an array-like object from it and then apply `Series.mask`.
33
34
34
35
Returns
35
36
-------
@@ -104,7 +105,7 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
104
105
2 -1
105
106
Name: a, dtype: int64
106
107
107
- The index will always follow that of `obj` . For example:
108
+ The index of `obj` is not maintained; the result always has a reset index. For example:
108
109
>>> df = pd.DataFrame(
109
110
... dict(a=[1, 2, 3], b=[4, 5, 6]),
110
111
... index=['index 1', 'index 2', 'index 3']
@@ -121,9 +122,9 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
121
122
... df.b,
122
123
... default=0,
123
124
... )
124
- index 1 4
125
- index 2 0
126
- index 3 0
125
+ 0 4
126
+ 1 0
127
+ 2 0
127
128
dtype: int64
128
129
"""
129
130
len_args = len (args )
@@ -139,9 +140,9 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
139
140
140
141
# construct series on which we will apply `Series.mask`
141
142
if is_array_like (default ):
142
- series = pd .Series (default , index = obj . index )
143
+ series = pd .Series (default ). reset_index ( drop = True )
143
144
else :
144
- series = pd .Series ([default ] * obj .shape [0 ], index = obj . index )
145
+ series = pd .Series ([default ] * obj .shape [0 ])
145
146
146
147
for i in range (0 , len_args , 2 ):
147
148
# get conditions
@@ -153,6 +154,13 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
153
154
# get replacements
154
155
replacements = args [i + 1 ]
155
156
157
+ # if `conditions` or `replacements` are `Series`, reset their index so they line up with the reset index of `series`
158
+ if isinstance (conditions , pd .Series ):
159
+ conditions = conditions .reset_index (drop = True )
160
+
161
+ if isinstance (replacements , pd .Series ):
162
+ replacements = replacements .reset_index (drop = True )
163
+
156
164
# `Series.mask` call
157
165
series = series .mask (conditions , replacements )
158
166
0 commit comments