From 474ac3478f8117f6bdc4992fae21f57c3bfdd86d Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 22 Mar 2024 09:15:10 +0100 Subject: [PATCH] fix: properly support nulls_last + descending (#15212) --- .../src/chunked_array/ops/compare_inner.rs | 68 ++++++++++++++++--- .../polars-core/src/chunked_array/ops/mod.rs | 1 + .../ops/sort/arg_sort_multiple.rs | 11 ++- .../src/chunked_array/ops/sort/mod.rs | 6 +- crates/polars-core/src/frame/mod.rs | 1 + .../src/physical_plan/expressions/sortby.rs | 3 + crates/polars-row/src/decode.rs | 2 - py-polars/tests/unit/operations/test_sort.py | 22 ++++++ .../unit/streaming/test_streaming_sort.py | 26 +++++++ 9 files changed, 125 insertions(+), 15 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/compare_inner.rs b/crates/polars-core/src/chunked_array/ops/compare_inner.rs index 02981d585144..8c1f52e54dd4 100644 --- a/crates/polars-core/src/chunked_array/ops/compare_inner.rs +++ b/crates/polars-core/src/chunked_array/ops/compare_inner.rs @@ -7,7 +7,47 @@ use crate::prelude::*; use crate::series::implementations::null::NullChunked; #[repr(transparent)] -struct NonNull(T); +#[derive(Copy, Clone)] +pub struct NonNull(pub T); + +impl TotalEq for NonNull { + fn tot_eq(&self, other: &Self) -> bool { + self.0.tot_eq(&other.0) + } +} + +pub trait NullOrderCmp { + fn null_order_cmp(&self, other: &Self, nulls_last: bool) -> Ordering; +} + +impl NullOrderCmp for Option { + fn null_order_cmp(&self, other: &Self, nulls_last: bool) -> Ordering { + match (self, other) { + (None, None) => Ordering::Equal, + (None, Some(_)) => { + if nulls_last { + Ordering::Greater + } else { + Ordering::Less + } + }, + (Some(_), None) => { + if nulls_last { + Ordering::Less + } else { + Ordering::Greater + } + }, + (Some(l), Some(r)) => l.tot_cmp(r), + } + } +} + +impl NullOrderCmp for NonNull { + fn null_order_cmp(&self, other: &Self, _nulls_last: bool) -> Ordering { + self.0.tot_cmp(&other.0) + } +} trait GetInner { type Item; @@ -29,16 +69,16 @@ impl<'a, T: StaticArray> GetInner for &'a T { } impl<'a, T: PolarsDataType> GetInner for NonNull<&'a ChunkedArray> { - type Item = T::Physical<'a>; + type Item = NonNull>; unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { - self.0.value_unchecked(idx) + NonNull(self.0.value_unchecked(idx)) } } impl<'a, T: StaticArray> GetInner for NonNull<&'a T> { - type Item = T::ValueT<'a>; + type Item = NonNull>; unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { - self.0.value_unchecked(idx) + NonNull(self.0.value_unchecked(idx)) } } @@ -51,7 +91,12 @@ pub trait TotalEqInner: Send + Sync { pub trait TotalOrdInner: Send + Sync { /// # Safety /// Does not do any bound checks. - unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering; + unsafe fn cmp_element_unchecked( + &self, + idx_a: usize, + idx_b: usize, + nulls_last: bool, + ) -> Ordering; } impl TotalEqInner for T @@ -102,13 +147,18 @@ where impl TotalOrdInner for T where T: GetInner + Send + Sync, - T::Item: TotalOrd, + T::Item: NullOrderCmp, { #[inline] - unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering { + unsafe fn cmp_element_unchecked( + &self, + idx_a: usize, + idx_b: usize, + nulls_last: bool, + ) -> Ordering { let a = self.get_unchecked(idx_a); let b = self.get_unchecked(idx_b); - a.tot_cmp(&b) + a.null_order_cmp(&b, nulls_last) } } diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index f8810149927d..57c0d31db405 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -375,6 +375,7 @@ pub struct SortOptions { pub struct SortMultipleOptions { pub other: Vec, pub descending: Vec, + pub nulls_last: bool, pub multithreaded: bool, } diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index c1e2fe379155..44a67af75294 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -1,3 +1,4 @@ +use compare_inner::NullOrderCmp; use polars_row::{convert_columns, RowsEncoded, SortField}; use polars_utils::iter::EnumerateIdxTrait; @@ -21,7 +22,7 @@ pub(crate) fn args_validate( Ok(()) } -pub(crate) fn arg_sort_multiple_impl( +pub(crate) fn arg_sort_multiple_impl( mut vals: Vec<(IdxSize, T)>, options: &SortMultipleOptions, ) -> PolarsResult { @@ -36,7 +37,12 @@ pub(crate) fn arg_sort_multiple_impl( let first_descending = descending[0]; POOL.install(|| { vals.par_sort_by(|tpl_a, tpl_b| { - match (first_descending, tpl_a.1.tot_cmp(&tpl_b.1)) { + match ( + first_descending, + tpl_a + .1 + .null_order_cmp(&tpl_b.1, options.nulls_last ^ first_descending), + ) { // if ordering is equal, we check the other arrays until we find a non-equal ordering // if we have exhausted all arrays, we keep the equal ordering. (_, Ordering::Equal) => { @@ -46,6 +52,7 @@ pub(crate) fn arg_sort_multiple_impl( ordering_other_columns( &compare_inner, descending.get_unchecked(1..), + options.nulls_last, idx_a, idx_b, ) diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index 9f7de82edc87..1610ea6b2fb4 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -10,6 +10,7 @@ pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt; use arrow::bitmap::MutableBitmap; use arrow::buffer::Buffer; use arrow::legacy::trusted_len::TrustedLenPush; +use compare_inner::NonNull; use rayon::prelude::*; pub use slice::*; @@ -221,7 +222,7 @@ fn arg_sort_multiple_numeric( vals.extend_trusted_len(arr.values().as_slice().iter().map(|v| { let i = count; count += 1; - (i, *v) + (i, NonNull(*v)) })) } arg_sort_multiple_impl(vals, options) @@ -269,13 +270,14 @@ where fn ordering_other_columns<'a>( compare_inner: &'a [Box], descending: &[bool], + nulls_last: bool, idx_a: usize, idx_b: usize, ) -> Ordering { for (cmp, descending) in compare_inner.iter().zip(descending) { // SAFETY: // indices are in bounds - let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b) }; + let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, nulls_last ^ descending) }; match (ordering, descending) { (Ordering::Equal, _) => continue, (_, true) => return ordering.reverse(), diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index f138dedffe1a..19db5298b088 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1860,6 +1860,7 @@ impl DataFrame { let options = SortMultipleOptions { other, descending, + nulls_last, multithreaded: parallel, }; first.arg_sort_multiple(&options)? diff --git a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs index d213af7631ed..e1b7e703a310 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs @@ -135,6 +135,7 @@ fn sort_by_groups_multiple_by( let options = SortMultipleOptions { other: groups[1..].to_vec(), descending: descending.to_owned(), + nulls_last: false, multithreaded: false, }; @@ -150,6 +151,7 @@ fn sort_by_groups_multiple_by( let options = SortMultipleOptions { other: groups[1..].to_vec(), descending: descending.to_owned(), + nulls_last: false, multithreaded: false, }; let sorted_idx = groups[0].arg_sort_multiple(&options).unwrap(); @@ -197,6 +199,7 @@ impl PhysicalExpr for SortByExpr { let options = SortMultipleOptions { other: s_sort_by[1..].to_vec(), descending, + nulls_last: false, multithreaded: true, }; s_sort_by[0].arg_sort_multiple(&options) diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index 5d9e514b9bbb..246ac976fc10 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -39,8 +39,6 @@ pub unsafe fn decode_rows( } unsafe fn decode(rows: &mut [&[u8]], field: &SortField, data_type: &ArrowDataType) -> ArrayRef { - // not yet supported for fixed types - assert!(!field.nulls_last, "not yet supported"); match data_type { ArrowDataType::Null => NullArray::new(ArrowDataType::Null, rows.len()).to_boxed(), ArrowDataType::Boolean => decode_bool(rows, field).to_boxed(), diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index aef3fcd8a5d9..64659fe274f8 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -982,3 +982,25 @@ def test_sorted_flag_concat_unit(unit_descending: bool) -> None: out = pl.concat((b, a)) assert out.to_list() == [3, 2, 1, None] assert out.flags["SORTED_DESC"] + + +@pytest.mark.parametrize("descending", [True, False]) +@pytest.mark.parametrize("nulls_last", [True, False]) +def test_sort_descending_nulls_last(descending: bool, nulls_last: bool) -> None: + df = pl.DataFrame({"x": [1, 3, None, 2, None], "y": [1, 3, 0, 2, 0]}) + + null_sentinel = 100 if descending ^ nulls_last else -100 + ref_x = [1, 3, None, 2, None] + ref_x.sort(key=lambda k: null_sentinel if k is None else k, reverse=descending) + ref_y = [1, 3, 0, 2, 0] + ref_y.sort(key=lambda k: null_sentinel if k == 0 else k, reverse=descending) + + assert_frame_equal( + df.sort("x", descending=descending, nulls_last=nulls_last), + pl.DataFrame({"x": ref_x, "y": ref_y}), + ) + + assert_frame_equal( + df.sort(["x", "y"], descending=descending, nulls_last=nulls_last), + pl.DataFrame({"x": ref_x, "y": ref_y}), + ) diff --git a/py-polars/tests/unit/streaming/test_streaming_sort.py b/py-polars/tests/unit/streaming/test_streaming_sort.py index e30859b74d3c..cff6c4c470e0 100644 --- a/py-polars/tests/unit/streaming/test_streaming_sort.py +++ b/py-polars/tests/unit/streaming/test_streaming_sort.py @@ -265,3 +265,29 @@ def test_nulls_last_streaming_sort() -> None: assert pl.LazyFrame({"x": [1, None]}).sort("x", nulls_last=True).collect( streaming=True ).to_dict(as_series=False) == {"x": [1, None]} + + +@pytest.mark.parametrize("descending", [True, False]) +@pytest.mark.parametrize("nulls_last", [True, False]) +def test_sort_descending_nulls_last(descending: bool, nulls_last: bool) -> None: + df = pl.DataFrame({"x": [1, 3, None, 2, None], "y": [1, 3, 0, 2, 0]}) + + null_sentinel = 100 if descending ^ nulls_last else -100 + ref_x = [1, 3, None, 2, None] + ref_x.sort(key=lambda k: null_sentinel if k is None else k, reverse=descending) + ref_y = [1, 3, 0, 2, 0] + ref_y.sort(key=lambda k: null_sentinel if k == 0 else k, reverse=descending) + + assert_frame_equal( + df.lazy() + .sort("x", descending=descending, nulls_last=nulls_last) + .collect(streaming=True), + pl.DataFrame({"x": ref_x, "y": ref_y}), + ) + + assert_frame_equal( + df.lazy() + .sort(["x", "y"], descending=descending, nulls_last=nulls_last) + .collect(streaming=True), + pl.DataFrame({"x": ref_x, "y": ref_y}), + )