Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Add dedicated no-null branch in arg_sort #16808

Merged
merged 2 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
use super::*;

/// Sorts `(index, value)` pairs in place, ordering them by the value
/// component via `TotalOrd::tot_cmp`.
///
/// This exists as its own function, generic only over `T`, to reduce
/// monomorphisation: callers that differ merely in their iterator types can
/// share one instantiation of the sort per value type.
///
/// The sort direction and threading choice are taken from
/// `options.descending` and `options.multithreaded` and forwarded to
/// `sort_by_branch`.
fn sort_impl<T>(vals: &mut [(IdxSize, T)], options: SortOptions)
where
    T: TotalOrd + Send + Sync,
{
    let descending = options.descending;
    let multithreaded = options.multithreaded;

    // Only the value (`.1`) takes part in the comparison; the index (`.0`)
    // is carried along as payload.
    let by_value = |lhs: &(IdxSize, T), rhs: &(IdxSize, T)| lhs.1.tot_cmp(&rhs.1);

    sort_by_branch(vals, descending, by_value, multithreaded);
}

pub(super) fn arg_sort<I, J, T>(
name: &str,
iters: I,
Expand All @@ -12,7 +25,6 @@ where
J: IntoIterator<Item = Option<T>>,
T: TotalOrd + Send + Sync,
{
let descending = options.descending;
let nulls_last = options.nulls_last;

let mut vals = Vec::with_capacity(len - null_count);
Expand All @@ -37,12 +49,7 @@ where
vals.extend(iter);
}

sort_by_branch(
vals.as_mut_slice(),
descending,
|a, b| a.1.tot_cmp(&b.1),
options.multithreaded,
);
sort_impl(vals.as_mut_slice(), options);

let iter = vals.into_iter().map(|(idx, _v)| idx);
let idx = if nulls_last {
Expand All @@ -60,3 +67,33 @@ where

ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
}

/// Computes the arg-sort (the indices that would sort the values) for input
/// that contains no nulls.
///
/// * `name` - name given to the resulting index array.
/// * `iters` - an iterator of per-chunk iterators; every item is a plain
///   value (`T`), not an `Option<T>`.
/// * `options` - sort options, forwarded unchanged to `sort_impl`.
/// * `len` - used only to pre-size the working buffer.
///
/// Indices are assigned from a single running counter across all inner
/// iterators, so they refer to positions in the concatenated input rather
/// than positions within an individual chunk.
pub(super) fn arg_sort_no_nulls<I, J, T>(
    name: &str,
    iters: I,
    options: SortOptions,
    len: usize,
) -> IdxCa
where
    I: IntoIterator<Item = J>,
    J: IntoIterator<Item = T>,
    T: TotalOrd + Send + Sync,
{
    let mut pairs = Vec::with_capacity(len);

    // The counter lives outside the loop so numbering continues from one
    // chunk into the next.
    let mut next_idx: IdxSize = 0;
    for chunk in iters {
        let indexed = chunk.into_iter().map(|value| {
            let pair = (next_idx, value);
            next_idx += 1;
            pair
        });
        pairs.extend(indexed);
    }

    sort_impl(pairs.as_mut_slice(), options);

    // Drop the values; what remains is the original positions in sorted order.
    let sorted_idx: Vec<_> = pairs.into_iter().map(|(idx, _)| idx).collect_trusted();
    let arr = IdxArr::from_data_default(Buffer::from(sorted_idx), None);

    ChunkedArray::with_chunk(name, arr)
}
139 changes: 69 additions & 70 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ where
}
}

#[inline]
fn sort_unstable_by_branch<T, C>(slice: &mut [T], descending: bool, cmp: C, parallel: bool)
where
T: Send,
Expand All @@ -62,6 +61,19 @@ where
}
}

/// Sorts `vals` in place using the total ordering of `T`
/// (`TotalOrd::tot_cmp`), delegating to `sort_unstable_by_branch`.
///
/// This exists as its own function, generic only over `T`, to reduce
/// monomorphisation: the call sites share one instantiation per element type
/// instead of each expanding the sort with its own comparator argument.
///
/// The sort direction and the parallel flag are taken from
/// `options.descending` and `options.multithreaded`.
fn sort_impl_unstable<T>(vals: &mut [T], options: SortOptions)
where
    T: TotalOrd + Send + Sync,
{
    let descending = options.descending;
    let parallel = options.multithreaded;

    sort_unstable_by_branch(
        vals,
        descending,
        |lhs: &T, rhs: &T| T::tot_cmp(lhs, rhs),
        parallel,
    );
}

macro_rules! sort_with_fast_path {
($ca:ident, $options:expr) => {{
if $ca.is_empty() {
Expand Down Expand Up @@ -103,12 +115,7 @@ where
if ca.null_count() == 0 {
let mut vals = ca.to_vec_null_aware().left().unwrap();

sort_unstable_by_branch(
vals.as_mut_slice(),
options.descending,
TotalOrd::tot_cmp,
options.multithreaded,
);
sort_impl_unstable(vals.as_mut_slice(), options);

let mut ca = ChunkedArray::from_vec(ca.name(), vals);
let s = if options.descending {
Expand Down Expand Up @@ -139,12 +146,7 @@ where
&mut vals[null_count..]
};

sort_unstable_by_branch(
mut_slice,
options.descending,
TotalOrd::tot_cmp,
options.multithreaded,
);
sort_impl_unstable(mut_slice, options);

let mut validity = MutableBitmap::with_capacity(len);
if options.nulls_last {
Expand Down Expand Up @@ -176,31 +178,11 @@ fn arg_sort_numeric<T>(ca: &ChunkedArray<T>, options: SortOptions) -> IdxCa
where
T: PolarsNumericType,
{
let descending = options.descending;
if ca.null_count() == 0 {
let mut vals = Vec::with_capacity(ca.len());
let mut count: IdxSize = 0;
ca.downcast_iter().for_each(|arr| {
let values = arr.values();
let iter = values.iter().map(|&v| {
let i = count;
count += 1;
(i, v)
});
vals.extend_trusted_len(iter);
});

sort_by_branch(
vals.as_mut_slice(),
descending,
|a, b| a.1.tot_cmp(&b.1),
options.multithreaded,
);

let out: NoNull<IdxCa> = vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
let mut out = out.into_inner();
out.rename(ca.name());
out
let iter = ca
.downcast_iter()
.map(|arr| arr.values().as_slice().iter().copied());
arg_sort::arg_sort_no_nulls(ca.name(), iter, options, ca.len())
} else {
let iter = ca
.downcast_iter()
Expand Down Expand Up @@ -337,12 +319,7 @@ impl ChunkSort<BinaryType> for BinaryChunked {
for arr in self.downcast_iter() {
v.extend(arr.non_null_values_iter());
}
sort_unstable_by_branch(
v.as_mut_slice(),
options.descending,
Ord::cmp,
options.multithreaded,
);
sort_impl_unstable(v.as_mut_slice(), options);

let len = self.len();
let null_count = self.null_count();
Expand Down Expand Up @@ -380,13 +357,22 @@ impl ChunkSort<BinaryType> for BinaryChunked {
}

fn arg_sort(&self, options: SortOptions) -> IdxCa {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
if self.null_count() == 0 {
arg_sort::arg_sort_no_nulls(
self.name(),
self.downcast_iter().map(|arr| arr.values_iter()),
options,
self.len(),
)
} else {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
}
}

fn arg_sort_multiple(
Expand Down Expand Up @@ -420,12 +406,7 @@ impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
v.extend(arr.non_null_values_iter());
}

sort_unstable_by_branch(
v.as_mut_slice(),
options.descending,
Ord::cmp,
options.multithreaded,
);
sort_impl_unstable(v.as_mut_slice(), options);

let mut values = Vec::<u8>::with_capacity(self.get_values_size());
let mut offsets = Vec::<i64>::with_capacity(self.len() + 1);
Expand Down Expand Up @@ -511,13 +492,22 @@ impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
}

fn arg_sort(&self, options: SortOptions) -> IdxCa {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
if self.null_count() == 0 {
arg_sort::arg_sort_no_nulls(
self.name(),
self.downcast_iter().map(|arr| arr.values_iter()),
options,
self.len(),
)
} else {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
}
}

/// # Panics
Expand Down Expand Up @@ -609,13 +599,22 @@ impl ChunkSort<BooleanType> for BooleanChunked {
}

fn arg_sort(&self, options: SortOptions) -> IdxCa {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
if self.null_count() == 0 {
arg_sort::arg_sort_no_nulls(
self.name(),
self.downcast_iter().map(|arr| arr.values_iter()),
options,
self.len(),
)
} else {
arg_sort::arg_sort(
self.name(),
self.downcast_iter().map(|arr| arr.iter()),
options,
self.null_count(),
self.len(),
)
}
}
fn arg_sort_multiple(
&self,
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,3 +1000,15 @@ def test_sort_nan_1942() -> None:
end = time.time()

assert (end - start) < 1.0


def test_sort_chunked_no_nulls() -> None:
    """arg_sort on a concatenated, null-free column returns frame-wide indices."""
    part = pl.DataFrame({"values": [3.0, 2.0]})
    # rechunk=False: presumably leaves the column split into two chunks, so
    # the indices must be counted across chunks rather than restarting at 0.
    stacked = pl.concat([part, part], rechunk=False)

    result = stacked.with_columns(pl.col("values").arg_sort())["values"]

    # Values are [3.0, 2.0, 3.0, 2.0]; both 2.0 rows (1, 3) sort before both
    # 3.0 rows (0, 2).
    expected = [1, 3, 0, 2]
    assert result.to_list() == expected