-
-
Notifications
You must be signed in to change notification settings - Fork 18.8k
PERF: Cythonize Groupby Rank #19481
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PERF: Cythonize Groupby Rank #19481
Changes from 1 commit
396f1b6
c2c2177
529503f
c7faa3b
baeb192
07c8e0f
2ba6643
4e54aa5
428d32c
902ef3c
ecd4b51
e17433d
b0ea557
e15b4b2
04eb4f1
7a4602d
7be3bf3
ca28350
913ce94
4755941
d4a6662
56e7974
9d7c3e6
178654d
f6ae88a
caacef2
a315a92
a6ca485
fd29d70
b9e4719
613384c
94a2749
3ee99c0
b430635
aa4578d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -140,29 +140,37 @@ def group_rank_object(ndarray[float64_t, ndim=2] out, | |
int tiebreak | ||
Py_ssize_t i, j, N, K | ||
int64_t val_start=0, grp_start=0, dups=0, sum_ranks=0, vals_seen=1 | ||
int64_t grp_na_count=0 | ||
ndarray[int64_t] _as | ||
bint pct, ascending | ||
ndarray[object] _values | ||
bint pct, ascending, keep_na | ||
|
||
tiebreak = tiebreakers[kwargs['ties_method']] | ||
ascending = kwargs['ascending'] | ||
pct = kwargs['pct'] | ||
keep_na = kwargs['na_option'] == 'keep' | ||
N, K = (<object> values).shape | ||
|
||
vals = np.array(values[:, 0], copy=True) | ||
mask = missing.isnaobj(vals) | ||
_values = np.array(values[:, 0], copy=True) | ||
mask = missing.isnaobj(_values) | ||
|
||
if ascending ^ (kwargs['na_option'] == 'top'): | ||
nan_value = np.inf | ||
order = (_values, mask, labels) | ||
else: | ||
nan_value = -np.inf | ||
order = (_values, ~mask, labels) | ||
np.putmask(_values, mask, nan_value) | ||
try: | ||
_as = np.lexsort((vals, labels)) | ||
_as = np.lexsort(order) | ||
except TypeError: | ||
# lexsort fails when missing data and objects are mixed | ||
# fallback to argsort | ||
order = (vals, mask, labels) | ||
_values = np.asarray(list(zip(order[0], order[1], order[2])), | ||
dtype=[('values', 'O'), ('mask', '?'), | ||
('labels', 'i8')]) | ||
_as = np.argsort(_values, kind='mergesort', order=('labels', | ||
'mask', 'values')) | ||
_arr = np.asarray(list(zip(order[0], order[1], order[2])), | ||
dtype=[('values', 'O'), ('mask', '?'), | ||
('labels', 'i8')]) | ||
_as = np.argsort(_arr, kind='mergesort', order=('labels', | ||
'mask', 'values')) | ||
|
||
if not ascending: | ||
_as = _as[::-1] | ||
|
@@ -171,7 +179,8 @@ def group_rank_object(ndarray[float64_t, ndim=2] out, | |
dups += 1 | ||
sum_ranks += i - grp_start + 1 | ||
|
||
if keep_na and mask[_as[i]]: | ||
if keep_na and (values[_as[i], 0] != values[_as[i], 0]): | ||
grp_na_count += 1 | ||
out[_as[i], 0] = np.nan | ||
else: | ||
if tiebreak == TIEBREAK_AVERAGE: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general this looping mechanism isn't very efficient because it overwrites "duplicate" values continually in a loop. Given the benchmarks were still significantly faster I left as is and was planning to open up a separate change to optimize further, but could review as part of this if you feel this loop mechanism is not acceptable |
||
|
@@ -204,8 +213,11 @@ def group_rank_object(ndarray[float64_t, ndim=2] out, | |
if i == N - 1 or labels[_as[i]] != labels[_as[i+1]]: | ||
if pct: | ||
for j in range(grp_start, i + 1): | ||
out[_as[j], 0] = out[_as[j], 0] / (i - grp_start + 1) | ||
out[_as[j], 0] = out[_as[j], 0] / (i - grp_start + 1 | ||
- grp_na_count) | ||
grp_na_count = 0 | ||
grp_start = i + 1 | ||
vals_seen = 1 | ||
|
||
|
||
cdef inline float64_t median_linear(float64_t* a, int n) nogil: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Due to my limited understanding I am not using the K value that gets extracted here, as I couldn't figure out under what circumstance K was ever not equal to 0. Can you advise how that works or what to look at to help my comprehension?