-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: Return largest non-NaN value for `max()` on sorted float arrays if it exists instead of NaN (#15060)
- Loading branch information
1 parent
0abbe5c
commit ec04150
Showing 8 changed files with 329 additions and 183 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
//! Arg-max helpers for sorted float arrays that prefer the largest non-NaN value. | ||
use num_traits::Float; | ||
|
||
use self::search_sorted::{ | ||
binary_search_array, slice_sorted_non_null_and_offset, SearchSortedSide, | ||
}; | ||
use crate::prelude::*; | ||
|
||
impl<T> ChunkedArray<T> | ||
where | ||
T: PolarsFloatType, | ||
T::Native: Float, | ||
{ | ||
fn float_arg_max_sorted_ascending(&self) -> usize { | ||
let ca = self; | ||
debug_assert!(ca.is_sorted_ascending_flag()); | ||
let is_descending = false; | ||
let side = SearchSortedSide::Left; | ||
|
||
let maybe_max_idx = ca.last_non_null().unwrap(); | ||
|
||
let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) }; | ||
if !maybe_max.is_nan() { | ||
return maybe_max_idx; | ||
} | ||
|
||
let (offset, ca) = unsafe { slice_sorted_non_null_and_offset(ca) }; | ||
let arr = unsafe { ca.downcast_get_unchecked(0) }; | ||
let search_val = T::Native::nan(); | ||
let idx = binary_search_array(side, arr, search_val, is_descending) as usize; | ||
|
||
let idx = idx.saturating_sub(1); | ||
|
||
offset + idx | ||
} | ||
|
||
fn float_arg_max_sorted_descending(&self) -> usize { | ||
let ca = self; | ||
debug_assert!(ca.is_sorted_descending_flag()); | ||
let is_descending = true; | ||
let side = SearchSortedSide::Right; | ||
|
||
let maybe_max_idx = ca.first_non_null().unwrap(); | ||
|
||
let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) }; | ||
if !maybe_max.is_nan() { | ||
return maybe_max_idx; | ||
} | ||
|
||
let (offset, ca) = unsafe { slice_sorted_non_null_and_offset(ca) }; | ||
let arr = unsafe { ca.downcast_get_unchecked(0) }; | ||
let search_val = T::Native::nan(); | ||
let idx = binary_search_array(side, arr, search_val, is_descending) as usize; | ||
|
||
let idx = if idx == arr.len() { idx - 1 } else { idx }; | ||
|
||
offset + idx | ||
} | ||
} | ||
|
||
/// # Safety | ||
/// `ca` has a float dtype, has at least 1 non-null value and is sorted ascending | ||
pub fn float_arg_max_sorted_ascending<T>(ca: &ChunkedArray<T>) -> usize | ||
where | ||
T: PolarsNumericType, | ||
{ | ||
with_match_physical_float_polars_type!(ca.dtype(), |$T| { | ||
let ca: &ChunkedArray<$T> = unsafe { | ||
&*(ca as *const ChunkedArray<T> as *const ChunkedArray<$T>) | ||
}; | ||
ca.float_arg_max_sorted_ascending() | ||
}) | ||
} | ||
|
||
/// # Safety | ||
/// `ca` has a float dtype, has at least 1 non-null value and is sorted descending | ||
pub fn float_arg_max_sorted_descending<T>(ca: &ChunkedArray<T>) -> usize | ||
where | ||
T: PolarsNumericType, | ||
{ | ||
with_match_physical_float_polars_type!(ca.dtype(), |$T| { | ||
let ca: &ChunkedArray<$T> = unsafe { | ||
&*(ca as *const ChunkedArray<T> as *const ChunkedArray<$T>) | ||
}; | ||
ca.float_arg_max_sorted_descending() | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
128 changes: 128 additions & 0 deletions
128
crates/polars-core/src/chunked_array/ops/search_sorted.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
use std::cmp::Ordering; | ||
use std::fmt::Debug; | ||
|
||
#[cfg(feature = "serde")] | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use crate::prelude::*; | ||
|
||
/// Selects which index a sorted search reports when the searched value occurs
/// in a run of equal elements.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SearchSortedSide {
    /// Report whichever matching index the search happened to hit.
    #[default]
    Any,
    /// Report the index of the first element of the equal run.
    Left,
    /// Report the index one past the last element of the equal run.
    Right,
}
|
||
/// Search the left or right index that still fulfills the requirements. | ||
fn get_side_idx<'a, A>(side: SearchSortedSide, mid: IdxSize, arr: &'a A, len: usize) -> IdxSize | ||
where | ||
A: StaticArray, | ||
A::ValueT<'a>: TotalOrd + Debug + Copy, | ||
{ | ||
let mut mid = mid; | ||
|
||
// approach the boundary from any side | ||
// this is O(n) we could make this binary search later | ||
match side { | ||
SearchSortedSide::Any => mid, | ||
SearchSortedSide::Left => { | ||
if mid as usize == len { | ||
mid -= 1; | ||
} | ||
|
||
let current = unsafe { arr.get_unchecked(mid as usize) }; | ||
loop { | ||
if mid == 0 { | ||
return mid; | ||
} | ||
mid -= 1; | ||
if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) { | ||
return mid + 1; | ||
} | ||
} | ||
}, | ||
SearchSortedSide::Right => { | ||
if mid as usize == len { | ||
return mid; | ||
} | ||
let current = unsafe { arr.get_unchecked(mid as usize) }; | ||
let bound = (len - 1) as IdxSize; | ||
loop { | ||
if mid >= bound { | ||
return mid + 1; | ||
} | ||
mid += 1; | ||
if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) { | ||
return mid; | ||
} | ||
} | ||
}, | ||
} | ||
} | ||
|
||
pub fn binary_search_array<'a, A>( | ||
side: SearchSortedSide, | ||
arr: &'a A, | ||
search_value: A::ValueT<'a>, | ||
descending: bool, | ||
) -> IdxSize | ||
where | ||
A: StaticArray, | ||
A::ValueT<'a>: TotalOrd + Debug + Copy, | ||
{ | ||
let mut size = arr.len() as IdxSize; | ||
let mut left = 0 as IdxSize; | ||
let mut right = size; | ||
while left < right { | ||
let mid = left + size / 2; | ||
|
||
// SAFETY: the call is made safe by the following invariants: | ||
// - `mid >= 0` | ||
// - `mid < size`: `mid` is limited by `[left; right)` bound. | ||
let cmp = match unsafe { arr.get_unchecked(mid as usize) } { | ||
None => Ordering::Less, | ||
Some(value) => { | ||
if descending { | ||
search_value.tot_cmp(&value) | ||
} else { | ||
value.tot_cmp(&search_value) | ||
} | ||
}, | ||
}; | ||
|
||
// The reason why we use if/else control flow rather than match | ||
// is because match reorders comparison operations, which is perf sensitive. | ||
// This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra. | ||
if cmp == Ordering::Less { | ||
left = mid + 1; | ||
} else if cmp == Ordering::Greater { | ||
right = mid; | ||
} else { | ||
return get_side_idx(side, mid, arr, arr.len()); | ||
} | ||
|
||
size = right - left; | ||
} | ||
|
||
left | ||
} | ||
|
||
/// Get a slice of the non-null values of a sorted array. The returned array | ||
/// will have a single chunk. | ||
/// # Safety | ||
/// The array is sorted and has at least one non-null value. | ||
pub unsafe fn slice_sorted_non_null_and_offset<T>(ca: &ChunkedArray<T>) -> (usize, ChunkedArray<T>) | ||
where | ||
T: PolarsDataType, | ||
{ | ||
let offset = ca.first_non_null().unwrap(); | ||
let length = 1 + ca.last_non_null().unwrap() - offset; | ||
let out = ca.slice(offset as i64, length); | ||
|
||
debug_assert!(out.null_count() != out.len()); | ||
debug_assert!(out.null_count() == 0); | ||
|
||
(offset, out.rechunk()) | ||
} |
Oops, something went wrong.