Skip to content

Commit

Permalink
fix: properly support nulls_last + descending (#15212)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Mar 22, 2024
1 parent e035226 commit 474ac34
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 15 deletions.
68 changes: 59 additions & 9 deletions crates/polars-core/src/chunked_array/ops/compare_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,47 @@ use crate::prelude::*;
use crate::series::implementations::null::NullChunked;

/// Newtype for values that are compared without any null handling: its
/// `NullOrderCmp` impl ignores the `nulls_last` flag and compares the wrapped
/// value directly with `tot_cmp`.
#[repr(transparent)]
struct NonNull<T>(T);
#[derive(Copy, Clone)]
pub struct NonNull<T>(pub T);

/// Equality on the wrapper defers entirely to the wrapped value's `TotalEq`.
impl<T: TotalEq> TotalEq for NonNull<T> {
    fn tot_eq(&self, other: &Self) -> bool {
        // Unpack both wrappers, then compare the payloads.
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.tot_eq(rhs)
    }
}

/// Ordering comparison in which the caller decides where nulls are placed.
pub trait NullOrderCmp {
    /// Compares `self` with `other`.
    ///
    /// `nulls_last` tells the implementation whether a null should compare as
    /// greater (`true`) or smaller (`false`) than a non-null value.
    /// Implementations for types that cannot hold a null may ignore it.
    fn null_order_cmp(&self, other: &Self, nulls_last: bool) -> Ordering;
}

/// `None` is the null: it ties with another `None` and otherwise sorts after
/// every `Some` when `nulls_last` is set, before every `Some` when it is not.
/// Two `Some` values are ordered by their `TotalOrd` comparison.
impl<T: TotalOrd> NullOrderCmp for Option<T> {
    fn null_order_cmp(&self, other: &Self, nulls_last: bool) -> Ordering {
        // How a null on the left-hand side ranks against a value on the right.
        let null_vs_value = if nulls_last {
            Ordering::Greater
        } else {
            Ordering::Less
        };

        match (self.as_ref(), other.as_ref()) {
            (Some(lhs), Some(rhs)) => lhs.tot_cmp(rhs),
            (None, None) => Ordering::Equal,
            (None, Some(_)) => null_vs_value,
            // Mirror image of the case above.
            (Some(_), None) => null_vs_value.reverse(),
        }
    }
}

impl<T: TotalOrd> NullOrderCmp for NonNull<T> {
fn null_order_cmp(&self, other: &Self, _nulls_last: bool) -> Ordering {
self.0.tot_cmp(&other.0)
}
}

trait GetInner {
type Item;
Expand All @@ -29,16 +69,16 @@ impl<'a, T: StaticArray> GetInner for &'a T {
}

// Unchecked element access for a `ChunkedArray` reference wrapped in `NonNull`:
// the element is read with `value_unchecked` and handed back wrapped in
// `NonNull`.
impl<'a, T: PolarsDataType> GetInner for NonNull<&'a ChunkedArray<T>> {
type Item = T::Physical<'a>;
type Item = NonNull<T::Physical<'a>>;
// NOTE(review): `idx` is presumably required to be in bounds by the caller, as
// the `unsafe`/`unchecked` naming suggests — confirm against `GetInner`.
unsafe fn get_unchecked(&self, idx: usize) -> Self::Item {
self.0.value_unchecked(idx)
NonNull(self.0.value_unchecked(idx))
}
}

// Unchecked element access for a `StaticArray` reference wrapped in `NonNull`:
// the element is read with `value_unchecked` and handed back wrapped in
// `NonNull`.
impl<'a, T: StaticArray> GetInner for NonNull<&'a T> {
type Item = T::ValueT<'a>;
type Item = NonNull<T::ValueT<'a>>;
// NOTE(review): `idx` is presumably required to be in bounds by the caller, as
// the `unsafe`/`unchecked` naming suggests — confirm against `GetInner`.
unsafe fn get_unchecked(&self, idx: usize) -> Self::Item {
self.0.value_unchecked(idx)
NonNull(self.0.value_unchecked(idx))
}
}

Expand All @@ -51,7 +91,12 @@ pub trait TotalEqInner: Send + Sync {
/// Comparator over two element positions of some underlying column.
/// `ordering_other_columns` consumes these as `Box<dyn TotalOrdInner>`.
pub trait TotalOrdInner: Send + Sync {
/// Compares the element at `idx_a` with the element at `idx_b`.
///
/// `nulls_last` is passed through to the element type's `null_order_cmp`.
///
/// # Safety
/// Does not do any bound checks.
unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering;
unsafe fn cmp_element_unchecked(
&self,
idx_a: usize,
idx_b: usize,
nulls_last: bool,
) -> Ordering;
}

impl<T> TotalEqInner for T
Expand Down Expand Up @@ -102,13 +147,18 @@ where
// Blanket impl: every thread-safe `GetInner` source whose items satisfy the
// comparison bound below can be used as a `TotalOrdInner`.
impl<T> TotalOrdInner for T
where
T: GetInner + Send + Sync,
T::Item: TotalOrd,
T::Item: NullOrderCmp,
{
#[inline]
unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering {
unsafe fn cmp_element_unchecked(
&self,
idx_a: usize,
idx_b: usize,
nulls_last: bool,
) -> Ordering {
// Fetch both elements without bounds checks; the caller vouches for the indices.
let a = self.get_unchecked(idx_a);
let b = self.get_unchecked(idx_b);
// Delegate the ordering (including null placement) to the element type.
a.tot_cmp(&b)
a.null_order_cmp(&b, nulls_last)
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ pub struct SortOptions {
/// Options passed to `arg_sort_multiple` when sorting by more than one column.
pub struct SortMultipleOptions {
/// The sort columns after the first one (callers pass `columns[1..]`).
pub other: Vec<Series>,
/// Descending flags for the sort columns.
/// NOTE(review): index 0 appears to belong to the first column and the rest
/// to `other`, in order — confirm against `arg_sort_multiple_impl`.
pub descending: Vec<bool>,
/// Whether nulls should be placed last. The sort code XORs this with a
/// column's descending flag before handing it to the element comparison.
pub nulls_last: bool,
/// NOTE(review): presumably enables parallel sorting (`DataFrame` sets it
/// from its `parallel` argument) — confirm.
pub multithreaded: bool,
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use compare_inner::NullOrderCmp;
use polars_row::{convert_columns, RowsEncoded, SortField};
use polars_utils::iter::EnumerateIdxTrait;

Expand All @@ -21,7 +22,7 @@ pub(crate) fn args_validate<T: PolarsDataType>(
Ok(())
}

pub(crate) fn arg_sort_multiple_impl<T: TotalOrd + Send + Copy>(
pub(crate) fn arg_sort_multiple_impl<T: NullOrderCmp + Send + Copy>(
mut vals: Vec<(IdxSize, T)>,
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
Expand All @@ -36,7 +37,12 @@ pub(crate) fn arg_sort_multiple_impl<T: TotalOrd + Send + Copy>(
let first_descending = descending[0];
POOL.install(|| {
vals.par_sort_by(|tpl_a, tpl_b| {
match (first_descending, tpl_a.1.tot_cmp(&tpl_b.1)) {
match (
first_descending,
tpl_a
.1
.null_order_cmp(&tpl_b.1, options.nulls_last ^ first_descending),
) {
// if ordering is equal, we check the other arrays until we find a non-equal ordering
// if we have exhausted all arrays, we keep the equal ordering.
(_, Ordering::Equal) => {
Expand All @@ -46,6 +52,7 @@ pub(crate) fn arg_sort_multiple_impl<T: TotalOrd + Send + Copy>(
ordering_other_columns(
&compare_inner,
descending.get_unchecked(1..),
options.nulls_last,
idx_a,
idx_b,
)
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt;
use arrow::bitmap::MutableBitmap;
use arrow::buffer::Buffer;
use arrow::legacy::trusted_len::TrustedLenPush;
use compare_inner::NonNull;
use rayon::prelude::*;
pub use slice::*;

Expand Down Expand Up @@ -221,7 +222,7 @@ fn arg_sort_multiple_numeric<T: PolarsNumericType>(
vals.extend_trusted_len(arr.values().as_slice().iter().map(|v| {
let i = count;
count += 1;
(i, *v)
(i, NonNull(*v))
}))
}
arg_sort_multiple_impl(vals, options)
Expand Down Expand Up @@ -269,13 +270,14 @@ where
fn ordering_other_columns<'a>(
compare_inner: &'a [Box<dyn TotalOrdInner + 'a>],
descending: &[bool],
nulls_last: bool,
idx_a: usize,
idx_b: usize,
) -> Ordering {
for (cmp, descending) in compare_inner.iter().zip(descending) {
// SAFETY:
// indices are in bounds
let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b) };
let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, nulls_last ^ descending) };
match (ordering, descending) {
(Ordering::Equal, _) => continue,
(_, true) => return ordering.reverse(),
Expand Down
1 change: 1 addition & 0 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1860,6 +1860,7 @@ impl DataFrame {
let options = SortMultipleOptions {
other,
descending,
nulls_last,
multithreaded: parallel,
};
first.arg_sort_multiple(&options)?
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-lazy/src/physical_plan/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ fn sort_by_groups_multiple_by(
let options = SortMultipleOptions {
other: groups[1..].to_vec(),
descending: descending.to_owned(),
nulls_last: false,
multithreaded: false,
};

Expand All @@ -150,6 +151,7 @@ fn sort_by_groups_multiple_by(
let options = SortMultipleOptions {
other: groups[1..].to_vec(),
descending: descending.to_owned(),
nulls_last: false,
multithreaded: false,
};
let sorted_idx = groups[0].arg_sort_multiple(&options).unwrap();
Expand Down Expand Up @@ -197,6 +199,7 @@ impl PhysicalExpr for SortByExpr {
let options = SortMultipleOptions {
other: s_sort_by[1..].to_vec(),
descending,
nulls_last: false,
multithreaded: true,
};
s_sort_by[0].arg_sort_multiple(&options)
Expand Down
2 changes: 0 additions & 2 deletions crates/polars-row/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ pub unsafe fn decode_rows(
}

unsafe fn decode(rows: &mut [&[u8]], field: &SortField, data_type: &ArrowDataType) -> ArrayRef {
// not yet supported for fixed types
assert!(!field.nulls_last, "not yet supported");
match data_type {
ArrowDataType::Null => NullArray::new(ArrowDataType::Null, rows.len()).to_boxed(),
ArrowDataType::Boolean => decode_bool(rows, field).to_boxed(),
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,3 +982,25 @@ def test_sorted_flag_concat_unit(unit_descending: bool) -> None:
out = pl.concat((b, a))
assert out.to_list() == [3, 2, 1, None]
assert out.flags["SORTED_DESC"]


@pytest.mark.parametrize("descending", [True, False])
@pytest.mark.parametrize("nulls_last", [True, False])
def test_sort_descending_nulls_last(descending: bool, nulls_last: bool) -> None:
    """Single- and multi-key sorts must honour ``nulls_last`` in both directions.

    Column ``y`` mirrors ``x`` with ``0`` standing in for each null, so both
    columns are expected to end up in the same order.
    """
    x_values = [1, 3, None, 2, None]
    y_values = [1, 3, 0, 2, 0]
    df = pl.DataFrame({"x": x_values, "y": y_values})

    # The reference sort is reversed when descending, so the stand-in for null
    # has to sit at the opposite extreme in that case.
    sentinel = 100 if descending != nulls_last else -100
    expected_x = sorted(
        x_values, key=lambda v: sentinel if v is None else v, reverse=descending
    )
    expected_y = sorted(
        y_values, key=lambda v: sentinel if v == 0 else v, reverse=descending
    )

    # First the single-key sort, then the multi-key sort.
    for by in ("x", ["x", "y"]):
        assert_frame_equal(
            df.sort(by, descending=descending, nulls_last=nulls_last),
            pl.DataFrame({"x": expected_x, "y": expected_y}),
        )
26 changes: 26 additions & 0 deletions py-polars/tests/unit/streaming/test_streaming_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,29 @@ def test_nulls_last_streaming_sort() -> None:
assert pl.LazyFrame({"x": [1, None]}).sort("x", nulls_last=True).collect(
streaming=True
).to_dict(as_series=False) == {"x": [1, None]}


@pytest.mark.parametrize("descending", [True, False])
@pytest.mark.parametrize("nulls_last", [True, False])
def test_sort_descending_nulls_last(descending: bool, nulls_last: bool) -> None:
    """Streaming single- and multi-key sorts must honour ``nulls_last``.

    Column ``y`` mirrors ``x`` with ``0`` standing in for each null, so both
    columns are expected to end up in the same order.
    """
    x_values = [1, 3, None, 2, None]
    y_values = [1, 3, 0, 2, 0]
    df = pl.DataFrame({"x": x_values, "y": y_values})

    # The reference sort is reversed when descending, so the stand-in for null
    # has to sit at the opposite extreme in that case.
    sentinel = 100 if descending != nulls_last else -100
    expected_x = sorted(
        x_values, key=lambda v: sentinel if v is None else v, reverse=descending
    )
    expected_y = sorted(
        y_values, key=lambda v: sentinel if v == 0 else v, reverse=descending
    )

    # First the single-key sort, then the multi-key sort.
    for by in ("x", ["x", "y"]):
        lazy_sorted = df.lazy().sort(by, descending=descending, nulls_last=nulls_last)
        assert_frame_equal(
            lazy_sorted.collect(streaming=True),
            pl.DataFrame({"x": expected_x, "y": expected_y}),
        )

0 comments on commit 474ac34

Please sign in to comment.