Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Return largest non-NaN value for max() on sorted float arrays if it exists instead of NaN #15060

Merged
merged 11 commits into from
Mar 15, 2024
85 changes: 48 additions & 37 deletions crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ use polars_utils::min_max::MinMax;
pub use quantile::*;
pub use var::*;

use super::float_sorted_arg_max::{
float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
};
use crate::chunked_array::ChunkedArray;
use crate::datatypes::{BooleanChunked, PolarsNumericType};
use crate::prelude::*;
Expand Down Expand Up @@ -93,21 +96,18 @@ where
}

fn min(&self) -> Option<T::Native> {
if self.is_empty() {
if self.null_count() == self.len() {
return None;
}
// There is at least one non-null value.
match self.is_sorted_flag() {
IsSorted::Ascending => {
self.first_non_null().and_then(|idx| {
// SAFETY: first_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
})
let idx = self.first_non_null().unwrap();
unsafe { self.get_unchecked(idx) }
},
IsSorted::Descending => {
self.last_non_null().and_then(|idx| {
// SAFETY: last returns in bound index.
unsafe { self.get_unchecked(idx) }
})
let idx = self.last_non_null().unwrap();
unsafe { self.get_unchecked(idx) }
},
IsSorted::Not => self
.downcast_iter()
Expand All @@ -117,23 +117,28 @@ where
}

fn max(&self) -> Option<T::Native> {
if self.is_empty() {
if self.null_count() == self.len() {
return None;
}
// There is at least one non-null value.
match self.is_sorted_flag() {
IsSorted::Ascending => {
self.last_non_null().and_then(|idx| {
// SAFETY:
// last_non_null returns in bound index
unsafe { self.get_unchecked(idx) }
})
let idx = if T::get_dtype().is_float() {
float_arg_max_sorted_ascending(self)
} else {
self.last_non_null().unwrap()
};

unsafe { self.get_unchecked(idx) }
},
IsSorted::Descending => {
self.first_non_null().and_then(|idx| {
// SAFETY:
// first_non_null returns in bound index
unsafe { self.get_unchecked(idx) }
})
let idx = if T::get_dtype().is_float() {
float_arg_max_sorted_descending(self)
} else {
self.first_non_null().unwrap()
};

unsafe { self.get_unchecked(idx) }
},
IsSorted::Not => self
.downcast_iter()
Expand All @@ -143,30 +148,36 @@ where
}

fn min_max(&self) -> Option<(T::Native, T::Native)> {
if self.is_empty() {
if self.null_count() == self.len() {
return None;
}
// There is at least one non-null value.
match self.is_sorted_flag() {
IsSorted::Ascending => {
let min = self.first_non_null().and_then(|idx| {
// SAFETY: first_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
});
let max = self.last_non_null().and_then(|idx| {
// SAFETY: last_non_null returns in bound index.
let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
let max = {
let idx = if T::get_dtype().is_float() {
float_arg_max_sorted_ascending(self)
} else {
self.last_non_null().unwrap()
};

unsafe { self.get_unchecked(idx) }
});
};
min.zip(max)
},
IsSorted::Descending => {
let max = self.first_non_null().and_then(|idx| {
// SAFETY: first_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
});
let min = self.last_non_null().and_then(|idx| {
// SAFETY: last_non_null returns in bound index.
let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
let max = {
let idx = if T::get_dtype().is_float() {
float_arg_max_sorted_descending(self)
} else {
self.first_non_null().unwrap()
};

unsafe { self.get_unchecked(idx) }
});
};

min.zip(max)
},
IsSorted::Not => self
Expand All @@ -182,10 +193,10 @@ where
}

fn mean(&self) -> Option<f64> {
if self.is_empty() || self.null_count() == self.len() {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by, the is_empty() call is redundant

if self.null_count() == self.len() {
return None;
}
match self.dtype() {
match T::get_dtype() {
DataType::Float64 => {
let len = (self.len() - self.null_count()) as f64;
self.sum().map(|v| v.to_f64().unwrap() / len)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//! Implementations of the ChunkAgg trait.
use num_traits::Float;

use self::search_sorted::{
binary_search_array, slice_sorted_non_null_and_offset, SearchSortedSide,
};
use crate::prelude::*;

impl<T> ChunkedArray<T>
where
T: PolarsFloatType,
T::Native: Float,
{
fn float_arg_max_sorted_ascending(&self) -> usize {
let ca = self;
debug_assert!(ca.is_sorted_ascending_flag());
let is_descending = false;
let side = SearchSortedSide::Left;

let maybe_max_idx = ca.last_non_null().unwrap();

let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) };
if !maybe_max.is_nan() {
return maybe_max_idx;
}

let (offset, ca) = unsafe { slice_sorted_non_null_and_offset(ca) };
let arr = unsafe { ca.downcast_get_unchecked(0) };
let search_val = T::Native::nan();
let idx = binary_search_array(side, arr, search_val, is_descending) as usize;

let idx = idx.saturating_sub(1);

offset + idx
}

fn float_arg_max_sorted_descending(&self) -> usize {
let ca = self;
debug_assert!(ca.is_sorted_descending_flag());
let is_descending = true;
let side = SearchSortedSide::Right;

let maybe_max_idx = ca.first_non_null().unwrap();

let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) };
if !maybe_max.is_nan() {
return maybe_max_idx;
}

let (offset, ca) = unsafe { slice_sorted_non_null_and_offset(ca) };
let arr = unsafe { ca.downcast_get_unchecked(0) };
let search_val = T::Native::nan();
let idx = binary_search_array(side, arr, search_val, is_descending) as usize;

let idx = if idx == arr.len() { idx - 1 } else { idx };

offset + idx
}
}

/// # Safety
/// `ca` has a float dtype, has at least 1 non-null value and is sorted ascending
pub fn float_arg_max_sorted_ascending<T>(ca: &ChunkedArray<T>) -> usize
where
T: PolarsNumericType,
{
with_match_physical_float_polars_type!(ca.dtype(), |$T| {
let ca: &ChunkedArray<$T> = unsafe {
&*(ca as *const ChunkedArray<T> as *const ChunkedArray<$T>)
};
ca.float_arg_max_sorted_ascending()
})
}

/// # Safety
/// `ca` has a float dtype, has at least 1 non-null value and is sorted descending
pub fn float_arg_max_sorted_descending<T>(ca: &ChunkedArray<T>) -> usize
where
T: PolarsNumericType,
{
with_match_physical_float_polars_type!(ca.dtype(), |$T| {
let ca: &ChunkedArray<$T> = unsafe {
&*(ca as *const ChunkedArray<T> as *const ChunkedArray<$T>)
};
ca.float_arg_max_sorted_descending()
})
}
2 changes: 2 additions & 0 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod explode_and_offsets;
mod extend;
pub mod fill_null;
mod filter;
pub mod float_sorted_arg_max;
mod for_each;
pub mod full;
pub mod gather;
Expand All @@ -29,6 +30,7 @@ pub(crate) mod min_max_binary;
pub(crate) mod nulls;
mod reverse;
pub(crate) mod rolling_window;
pub mod search_sorted;
mod set;
mod shift;
pub mod sort;
Expand Down
128 changes: 128 additions & 0 deletions crates/polars-core/src/chunked_array/ops/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::cmp::Ordering;
use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::prelude::*;

#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SearchSortedSide {
#[default]
Any,
Left,
Right,
}

/// Search the left or right index that still fulfills the requirements.
fn get_side_idx<'a, A>(side: SearchSortedSide, mid: IdxSize, arr: &'a A, len: usize) -> IdxSize
where
A: StaticArray,
A::ValueT<'a>: TotalOrd + Debug + Copy,
{
let mut mid = mid;

// approach the boundary from any side
// this is O(n) we could make this binary search later
match side {
SearchSortedSide::Any => mid,
SearchSortedSide::Left => {
if mid as usize == len {
mid -= 1;
}

let current = unsafe { arr.get_unchecked(mid as usize) };
loop {
if mid == 0 {
return mid;
}
mid -= 1;
if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) {
return mid + 1;
}
}
},
SearchSortedSide::Right => {
if mid as usize == len {
return mid;
}
let current = unsafe { arr.get_unchecked(mid as usize) };
let bound = (len - 1) as IdxSize;
loop {
if mid >= bound {
return mid + 1;
}
mid += 1;
if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) {
return mid;
}
}
},
}
}

pub fn binary_search_array<'a, A>(
side: SearchSortedSide,
arr: &'a A,
search_value: A::ValueT<'a>,
descending: bool,
) -> IdxSize
where
A: StaticArray,
A::ValueT<'a>: TotalOrd + Debug + Copy,
{
let mut size = arr.len() as IdxSize;
let mut left = 0 as IdxSize;
let mut right = size;
while left < right {
let mid = left + size / 2;

// SAFETY: the call is made safe by the following invariants:
// - `mid >= 0`
// - `mid < size`: `mid` is limited by `[left; right)` bound.
let cmp = match unsafe { arr.get_unchecked(mid as usize) } {
None => Ordering::Less,
Some(value) => {
if descending {
search_value.tot_cmp(&value)
} else {
value.tot_cmp(&search_value)
}
},
};

// The reason why we use if/else control flow rather than match
// is because match reorders comparison operations, which is perf sensitive.
// This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra.
if cmp == Ordering::Less {
left = mid + 1;
} else if cmp == Ordering::Greater {
right = mid;
} else {
return get_side_idx(side, mid, arr, arr.len());
}

size = right - left;
}

left
}

/// Get a slice of the non-null values of a sorted array. The returned array
/// will have a single chunk.
/// # Safety
/// The array is sorted and has at least one non-null value.
pub unsafe fn slice_sorted_non_null_and_offset<T>(ca: &ChunkedArray<T>) -> (usize, ChunkedArray<T>)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this for now because binary search does not yet work for arrays sorted with nulls last (#15045)

where
T: PolarsDataType,
{
let offset = ca.first_non_null().unwrap();
let length = 1 + ca.last_non_null().unwrap() - offset;
let out = ca.slice(offset as i64, length);

debug_assert!(out.null_count() != out.len());
debug_assert!(out.null_count() == 0);

(offset, out.rechunk())
}
Loading
Loading