@@ -28,7 +28,8 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
28
28
will be used to create the `Series` on which `Series.mask` will be called.
29
29
If this value is not already an array like (i.e. it is not of type `Series`,
30
30
`np.array` or `list`) it will be repeated `obj.shape[0]` times in order to
31
- create an array like object from it and then apply the `Series.mask`.
31
+ create an array like object from it and then apply the `Series.mask`. In any
32
+ case, the default series will be forced to take the index of `obj`.
32
33
33
34
Returns
34
35
-------
@@ -102,6 +103,28 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
102
103
1 -1
103
104
2 -1
104
105
Name: a, dtype: int64
106
+
107
+ The index will always follow that of `obj`. For example:
108
+ >>> df = pd.DataFrame(
109
+ ... dict(a=[1, 2, 3], b=[4, 5, 6]),
110
+ ... index=['index 1', 'index 2', 'index 3']
111
+ ... )
112
+ >>> df
113
+ a b
114
+ index 1 1 4
115
+ index 2 2 5
116
+ index 3 3 6
117
+
118
+ >>> pd.case_when(
119
+ ... df,
120
+ ... lambda x: (x.a == 1) & (x.b == 4),
121
+ ... df.b,
122
+ ... default=0,
123
+ ... )
124
+ index 1 4
125
+ index 2 0
126
+ index 3 0
127
+ dtype: int64
105
128
"""
106
129
len_args = len (args )
107
130
@@ -116,9 +139,9 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
116
139
117
140
# construct series on which we will apply `Series.mask`
118
141
if is_array_like (default ):
119
- series = pd .Series (default )
142
+ series = pd .Series (default , index = obj . index )
120
143
else :
121
- series = pd .Series ([default ] * obj .shape [0 ])
144
+ series = pd .Series ([default ] * obj .shape [0 ], index = obj . index )
122
145
123
146
for i in range (0 , len_args , 2 ):
124
147
# get conditions
0 commit comments