Skip to content

Commit

Permalink
fix: Return largest non-NaN value for max() on sorted float arrays …
Browse files Browse the repository at this point in the history
…if it exists instead of NaN (#15060)
  • Loading branch information
nameexhaustion authored Mar 15, 2024
1 parent 0abbe5c commit ec04150
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 183 deletions.
85 changes: 48 additions & 37 deletions crates/polars-core/src/chunked_array/ops/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ use polars_utils::min_max::MinMax;
pub use quantile::*;
pub use var::*;

use super::float_sorted_arg_max::{
float_arg_max_sorted_ascending, float_arg_max_sorted_descending,
};
use crate::chunked_array::ChunkedArray;
use crate::datatypes::{BooleanChunked, PolarsNumericType};
use crate::prelude::*;
Expand Down Expand Up @@ -93,21 +96,18 @@ where
}

fn min(&self) -> Option<T::Native> {
if self.is_empty() {
if self.null_count() == self.len() {
return None;
}
// There is at least one non-null value.
match self.is_sorted_flag() {
IsSorted::Ascending => {
self.first_non_null().and_then(|idx| {
// SAFETY: first_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
})
let idx = self.first_non_null().unwrap();
unsafe { self.get_unchecked(idx) }
},
IsSorted::Descending => {
self.last_non_null().and_then(|idx| {
// SAFETY: last returns in bound index.
unsafe { self.get_unchecked(idx) }
})
let idx = self.last_non_null().unwrap();
unsafe { self.get_unchecked(idx) }
},
IsSorted::Not => self
.downcast_iter()
Expand All @@ -117,23 +117,28 @@ where
}

fn max(&self) -> Option<T::Native> {
if self.is_empty() {
if self.null_count() == self.len() {
return None;
}
// There is at least one non-null value.
match self.is_sorted_flag() {
IsSorted::Ascending => {
self.last_non_null().and_then(|idx| {
// SAFETY:
// last_non_null returns in bound index
unsafe { self.get_unchecked(idx) }
})
let idx = if T::get_dtype().is_float() {
float_arg_max_sorted_ascending(self)
} else {
self.last_non_null().unwrap()
};

unsafe { self.get_unchecked(idx) }
},
IsSorted::Descending => {
self.first_non_null().and_then(|idx| {
// SAFETY:
// first_non_null returns in bound index
unsafe { self.get_unchecked(idx) }
})
let idx = if T::get_dtype().is_float() {
float_arg_max_sorted_descending(self)
} else {
self.first_non_null().unwrap()
};

unsafe { self.get_unchecked(idx) }
},
IsSorted::Not => self
.downcast_iter()
Expand All @@ -143,30 +148,36 @@ where
}

fn min_max(&self) -> Option<(T::Native, T::Native)> {
if self.is_empty() {
if self.null_count() == self.len() {
return None;
}
// There is at least one non-null value.
match self.is_sorted_flag() {
IsSorted::Ascending => {
let min = self.first_non_null().and_then(|idx| {
// SAFETY: first_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
});
let max = self.last_non_null().and_then(|idx| {
// SAFETY: last_non_null returns in bound index.
let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) };
let max = {
let idx = if T::get_dtype().is_float() {
float_arg_max_sorted_ascending(self)
} else {
self.last_non_null().unwrap()
};

unsafe { self.get_unchecked(idx) }
});
};
min.zip(max)
},
IsSorted::Descending => {
let max = self.first_non_null().and_then(|idx| {
// SAFETY: first_non_null returns in bound index.
unsafe { self.get_unchecked(idx) }
});
let min = self.last_non_null().and_then(|idx| {
// SAFETY: last_non_null returns in bound index.
let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) };
let max = {
let idx = if T::get_dtype().is_float() {
float_arg_max_sorted_descending(self)
} else {
self.first_non_null().unwrap()
};

unsafe { self.get_unchecked(idx) }
});
};

min.zip(max)
},
IsSorted::Not => self
Expand All @@ -182,10 +193,10 @@ where
}

fn mean(&self) -> Option<f64> {
if self.is_empty() || self.null_count() == self.len() {
if self.null_count() == self.len() {
return None;
}
match self.dtype() {
match T::get_dtype() {
DataType::Float64 => {
let len = (self.len() - self.null_count()) as f64;
self.sum().map(|v| v.to_f64().unwrap() / len)
Expand Down
87 changes: 87 additions & 0 deletions crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//! Arg-max helpers for sorted float `ChunkedArray`s that prefer the largest non-NaN value over NaN.
use num_traits::Float;

use self::search_sorted::{
binary_search_array, slice_sorted_non_null_and_offset, SearchSortedSide,
};
use crate::prelude::*;

impl<T> ChunkedArray<T>
where
    T: PolarsFloatType,
    T::Native: Float,
{
    /// Index of the maximum of an ascending-sorted float array, preferring a
    /// non-NaN value over NaN.
    ///
    /// If the last non-null value is not NaN it is the answer. Otherwise the
    /// non-null span is binary-searched for the left-most NaN and the element
    /// just before it is returned; when no element precedes it, the index of
    /// that NaN itself is returned.
    ///
    /// Panics if the array has no non-null value.
    fn float_arg_max_sorted_ascending(&self) -> usize {
        debug_assert!(self.is_sorted_ascending_flag());

        let last_idx = self.last_non_null().unwrap();
        // SAFETY: `last_non_null` yields an in-bounds index of a non-null value.
        let last_val = unsafe { self.value_unchecked(last_idx) };
        if !last_val.is_nan() {
            return last_idx;
        }

        // SAFETY: the array is flagged sorted and has a non-null value
        // (the `unwrap` above succeeded).
        let (offset, non_null) = unsafe { slice_sorted_non_null_and_offset(self) };
        // SAFETY: the helper rechunks its result, so chunk 0 exists.
        let arr = unsafe { non_null.downcast_get_unchecked(0) };

        // Left-most position of NaN within the ascending non-null span.
        let first_nan =
            binary_search_array(SearchSortedSide::Left, arr, T::Native::nan(), false) as usize;

        // Step back onto the last non-NaN value; stays at 0 if there is none.
        offset + first_nan.saturating_sub(1)
    }

    /// Index of the maximum of a descending-sorted float array, preferring a
    /// non-NaN value over NaN.
    ///
    /// If the first non-null value is not NaN it is the answer. Otherwise the
    /// non-null span is binary-searched for the position just past the NaN
    /// run; when that position is past the end of the span, the last element
    /// of the span is returned instead.
    ///
    /// Panics if the array has no non-null value.
    fn float_arg_max_sorted_descending(&self) -> usize {
        debug_assert!(self.is_sorted_descending_flag());

        let first_idx = self.first_non_null().unwrap();
        // SAFETY: `first_non_null` yields an in-bounds index of a non-null value.
        let first_val = unsafe { self.value_unchecked(first_idx) };
        if !first_val.is_nan() {
            return first_idx;
        }

        // SAFETY: the array is flagged sorted and has a non-null value
        // (the `unwrap` above succeeded).
        let (offset, non_null) = unsafe { slice_sorted_non_null_and_offset(self) };
        // SAFETY: the helper rechunks its result, so chunk 0 exists.
        let arr = unsafe { non_null.downcast_get_unchecked(0) };

        // Position one past the NaN run within the descending non-null span.
        let mut after_nan =
            binary_search_array(SearchSortedSide::Right, arr, T::Native::nan(), true) as usize;

        // Every value is NaN: clamp back onto the last element.
        if after_nan == arr.len() {
            after_nan -= 1;
        }

        offset + after_nan
    }
}

/// Index of the maximum value of an ascending-sorted float `ca`, preferring a
/// non-NaN value over NaN when one exists.
///
/// Dispatches on `ca.dtype()` and forwards to the float-specific method.
///
/// # Safety
/// `ca` has a float dtype, has at least 1 non-null value and is sorted ascending
pub fn float_arg_max_sorted_ascending<T>(ca: &ChunkedArray<T>) -> usize
where
    T: PolarsNumericType,
{
    with_match_physical_float_polars_type!(ca.dtype(), |$T| {
        let as_float = ca as *const ChunkedArray<T> as *const ChunkedArray<$T>;
        // SAFETY: presumably the macro binds `$T` to the float type matching
        // `ca.dtype()`, making this a same-type reinterpretation — TODO confirm
        // against the macro definition.
        let float_ca: &ChunkedArray<$T> = unsafe { &*as_float };
        float_ca.float_arg_max_sorted_ascending()
    })
}

/// Index of the maximum value of a descending-sorted float `ca`, preferring a
/// non-NaN value over NaN when one exists.
///
/// Dispatches on `ca.dtype()` and forwards to the float-specific method.
///
/// # Safety
/// `ca` has a float dtype, has at least 1 non-null value and is sorted descending
pub fn float_arg_max_sorted_descending<T>(ca: &ChunkedArray<T>) -> usize
where
    T: PolarsNumericType,
{
    with_match_physical_float_polars_type!(ca.dtype(), |$T| {
        let as_float = ca as *const ChunkedArray<T> as *const ChunkedArray<$T>;
        // SAFETY: presumably the macro binds `$T` to the float type matching
        // `ca.dtype()`, making this a same-type reinterpretation — TODO confirm
        // against the macro definition.
        let float_ca: &ChunkedArray<$T> = unsafe { &*as_float };
        float_ca.float_arg_max_sorted_descending()
    })
}
2 changes: 2 additions & 0 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod explode_and_offsets;
mod extend;
pub mod fill_null;
mod filter;
pub mod float_sorted_arg_max;
mod for_each;
pub mod full;
pub mod gather;
Expand All @@ -29,6 +30,7 @@ pub(crate) mod min_max_binary;
pub(crate) mod nulls;
mod reverse;
pub(crate) mod rolling_window;
pub mod search_sorted;
mod set;
mod shift;
pub mod sort;
Expand Down
128 changes: 128 additions & 0 deletions crates/polars-core/src/chunked_array/ops/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::cmp::Ordering;
use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::prelude::*;

/// Selects which index a sorted search reports when the searched value matches
/// a run of equal elements.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SearchSortedSide {
/// Report whichever matching index the binary search landed on.
#[default]
Any,
/// Report the index of the first element of the run of equal values.
Left,
/// Report the index one past the last element of the run of equal values.
Right,
}

/// Given an index `mid` whose element matched the search value, move to the
/// edge of the run of equal elements requested by `side`.
///
/// * `Any` returns `mid` unchanged.
/// * `Left` returns the index of the first element of the run.
/// * `Right` returns the index one past the last element of the run.
///
/// `len` is the number of elements in `arr`.
fn get_side_idx<'a, A>(side: SearchSortedSide, mid: IdxSize, arr: &'a A, len: usize) -> IdxSize
where
    A: StaticArray,
    A::ValueT<'a>: TotalOrd + Debug + Copy,
{
    // The walk towards the boundary is linear in the length of the run;
    // it could be turned into a binary search later.
    match side {
        SearchSortedSide::Any => mid,
        SearchSortedSide::Left => {
            // A one-past-the-end position is first pulled back onto the last element.
            let mut idx = if mid as usize == len { mid - 1 } else { mid };

            let target = unsafe { arr.get_unchecked(idx as usize) };
            while idx > 0 {
                let prev = unsafe { arr.get_unchecked((idx - 1) as usize) };
                if target.tot_ne(&prev) {
                    break;
                }
                idx -= 1;
            }
            idx
        },
        SearchSortedSide::Right => {
            if mid as usize == len {
                return mid;
            }

            let target = unsafe { arr.get_unchecked(mid as usize) };
            let last = (len - 1) as IdxSize;
            let mut idx = mid;
            while idx < last {
                idx += 1;
                let next = unsafe { arr.get_unchecked(idx as usize) };
                if target.tot_ne(&next) {
                    return idx;
                }
            }
            // The run extends to the final element inspected.
            idx + 1
        },
    }
}

/// Binary-search `arr` for `search_value` and return an index into it.
///
/// When an element compares equal to `search_value`, the returned index is
/// chosen by `side` (see `get_side_idx`). Otherwise the position at which the
/// search interval collapsed is returned.
///
/// `descending` flips the comparison direction. Null slots always compare as
/// `Ordering::Less`, which moves the search to the right of them.
pub fn binary_search_array<'a, A>(
    side: SearchSortedSide,
    arr: &'a A,
    search_value: A::ValueT<'a>,
    descending: bool,
) -> IdxSize
where
    A: StaticArray,
    A::ValueT<'a>: TotalOrd + Debug + Copy,
{
    let mut lo: IdxSize = 0;
    let mut hi = arr.len() as IdxSize;

    while lo < hi {
        let mid = lo + (hi - lo) / 2;

        // SAFETY: `lo <= mid < hi` and `hi` never exceeds the initial `arr.len()`,
        // so `mid` is a valid index.
        let ordering = match unsafe { arr.get_unchecked(mid as usize) } {
            Some(value) if descending => search_value.tot_cmp(&value),
            Some(value) => value.tot_cmp(&search_value),
            None => Ordering::Less,
        };

        // An if/else chain is used on purpose instead of `match`: `match`
        // reorders the comparisons, and this path is perf sensitive.
        // This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra.
        if ordering == Ordering::Less {
            lo = mid + 1;
        } else if ordering == Ordering::Greater {
            hi = mid;
        } else {
            return get_side_idx(side, mid, arr, arr.len());
        }
    }

    lo
}

/// Returns the span of `ca` running from its first to its last non-null value,
/// rechunked into a single chunk, together with the index in `ca` at which
/// that span starts.
///
/// Panics if `ca` contains no non-null value.
///
/// # Safety
/// The array is sorted and has at least one non-null value.
pub unsafe fn slice_sorted_non_null_and_offset<T>(ca: &ChunkedArray<T>) -> (usize, ChunkedArray<T>)
where
    T: PolarsDataType,
{
    let first = ca.first_non_null().unwrap();
    let last = ca.last_non_null().unwrap();
    let non_null = ca.slice(first as i64, 1 + last - first);

    // The span between the outermost non-null values is expected to be null-free.
    debug_assert!(non_null.null_count() != non_null.len());
    debug_assert!(non_null.null_count() == 0);

    (first, non_null.rechunk())
}
Loading

0 comments on commit ec04150

Please sign in to comment.