Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RunEndArray filter #1380

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 13 additions & 49 deletions encodings/runend/src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use std::cmp::min;

use arrow_buffer::BooleanBufferBuilder;
use itertools::Itertools;
use num_traits::{AsPrimitive, FromPrimitive};
use vortex_array::array::{BoolArray, BooleanBuffer, PrimitiveArray};
use vortex_array::validity::Validity;
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::ArrayDType;
use vortex_dtype::{match_each_integer_ptype, match_each_native_ptype, NativePType, Nullability};
use vortex_error::{vortex_panic, VortexResult};
use vortex_error::VortexResult;

use crate::iter::trimmed_ends_iter;

pub fn runend_encode(array: &PrimitiveArray) -> (PrimitiveArray, PrimitiveArray) {
let validity = if array.dtype().nullability() == Nullability::NonNullable {
Expand Down Expand Up @@ -61,9 +60,8 @@ pub fn runend_decode_primitive(
match_each_native_ptype!(values.ptype(), |$P| {
match_each_integer_ptype!(ends.ptype(), |$E| {
Ok(PrimitiveArray::from_vec(runend_decode_typed_primitive(
ends.maybe_null_slice::<$E>(),
trimmed_ends_iter(ends.maybe_null_slice::<$E>(), offset, length),
values.maybe_null_slice::<$P>(),
offset,
length,
), validity))
})
Expand All @@ -79,67 +77,33 @@ pub fn runend_decode_bools(
) -> VortexResult<BoolArray> {
match_each_integer_ptype!(ends.ptype(), |$E| {
BoolArray::try_new(runend_decode_typed_bool(
ends.maybe_null_slice::<$E>(),
trimmed_ends_iter(ends.maybe_null_slice::<$E>(), offset, length),
values.boolean_buffer(),
offset,
length,
), validity)
})
}

/// Rebases run ends onto a slice window: every end has `offset` subtracted from it and is then
/// capped at `length`. The ends keep their native type `E`.
///
/// # Panics
///
/// Panics through `vortex_panic!` when `offset` or `length` cannot be converted to `E`.
#[inline]
fn trimmed_run_ends<E: NativePType + AsPrimitive<usize> + FromPrimitive + Ord>(
    run_ends: &[E],
    offset: usize,
    length: usize,
) -> impl Iterator<Item = E> + use<'_, E> {
    let start = match E::from_usize(offset) {
        Some(converted) => converted,
        None => vortex_panic!(
            "offset {} cannot be converted to {}",
            offset,
            std::any::type_name::<E>()
        ),
    };
    let limit = match E::from_usize(length) {
        Some(converted) => converted,
        None => vortex_panic!(
            "length {} cannot be converted to {}",
            length,
            std::any::type_name::<E>()
        ),
    };
    // Shift and clamp each end in a single pass.
    run_ends.iter().map(move |&end| min(end - start, limit))
}

/// Expands run-end encoded values into a dense `Vec<T>`.
///
/// Run ends are paired with `values` through `zip_eq`; each value is appended repeatedly until
/// `decoded.len()` reaches its run end, so the ends act as cumulative end positions.
/// `length` is used to pre-size the output vector.
// NOTE(review): this span is a diff rendering without +/- markers, so two variants of the
// signature and of the loop appear together: one taking `run_ends: &[E]` plus `offset` and
// trimming the ends itself via `trimmed_run_ends`, and one taking an iterator of `usize` ends
// directly.
pub fn runend_decode_typed_primitive<
E: NativePType + AsPrimitive<usize> + FromPrimitive + Ord,
T: NativePType,
>(
run_ends: &[E],
pub fn runend_decode_typed_primitive<T: NativePType>(
run_ends: impl Iterator<Item = usize>,
values: &[T],
offset: usize,
length: usize,
) -> Vec<T> {
let trimmed_ends = trimmed_run_ends(run_ends, offset, length);
let mut decoded = Vec::with_capacity(length);
for (end, value) in trimmed_ends.zip_eq(values) {
decoded.extend(std::iter::repeat_n(value, end.as_() - decoded.len()));
for (end, value) in run_ends.zip_eq(values) {
decoded.extend(std::iter::repeat_n(value, end - decoded.len()));
}
decoded
}

/// Expands run-end encoded booleans into a dense `BooleanBuffer`.
///
/// Run ends are paired with the bits of `values` through `zip_eq`; each bit is appended with
/// `append_n` until `decoded.len()` reaches its run end. `length` sizes the builder up front.
// NOTE(review): this span is a diff rendering without +/- markers, so two variants of the
// signature and of the loop appear together: one taking `run_ends: &[E]` plus `offset` and
// trimming the ends itself via `trimmed_run_ends`, and one taking an iterator of `usize` ends
// directly.
pub fn runend_decode_typed_bool<E: NativePType + AsPrimitive<usize> + FromPrimitive + Ord>(
run_ends: &[E],
pub fn runend_decode_typed_bool(
run_ends: impl Iterator<Item = usize>,
values: BooleanBuffer,
offset: usize,
length: usize,
) -> BooleanBuffer {
let trimmed_ends = trimmed_run_ends(run_ends, offset, length);
let mut decoded = BooleanBufferBuilder::new(length);
for (end, value) in trimmed_ends.zip_eq(values.iter()) {
decoded.append_n(end.as_() - decoded.len(), value);
for (end, value) in run_ends.zip_eq(values.iter()) {
decoded.append_n(end - decoded.len(), value);
}
decoded.finish()
}
Expand Down
54 changes: 42 additions & 12 deletions encodings/runend/src/compute.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::min;
use std::ops::AddAssign;

use num_traits::AsPrimitive;
Expand Down Expand Up @@ -76,9 +77,8 @@ impl TakeFn for RunEndArray {
let primitive_indices = indices.clone().into_primitive()?;
let u64_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| {
primitive_indices
.maybe_null_slice::<$P>()
.iter()
.copied()
.into_maybe_null_slice::<$P>()
.into_iter()
.map(|idx| {
let usize_idx = idx as usize;
if usize_idx >= self.len() {
Expand All @@ -89,11 +89,11 @@ impl TakeFn for RunEndArray {
})
.collect::<VortexResult<Vec<u64>>>()?
});
let physical_indices: Vec<u64> = self
let physical_indices = self
.find_physical_indices(&u64_indices)?
.iter()
.map(|idx| *idx as u64)
.collect();
.into_iter()
.map(|idx| idx as u64)
.collect::<Vec<_>>();
let physical_indices_array = PrimitiveArray::from(physical_indices).into_array();
let dense_values = take(self.values(), &physical_indices_array, options)?;

Expand Down Expand Up @@ -146,12 +146,12 @@ impl SliceFn for RunEndArray {

// Filters a run-end encoded array: the run ends are rewritten by `filter_run_ends`, which also
// returns a mask that is then applied to the values, and the result is rebuilt with
// `RunEndArray::try_new`.
// NOTE(review): this span is a diff rendering without +/- markers, so lines from both sides of
// the change appear together. One variant calls `filter_run_ends` with only the ends and the
// mask and shadows `mask` with the returned one; the other also passes `self.offset()` and
// `self.len()` and binds the returned mask as `values_mask`.
impl FilterFn for RunEndArray {
fn filter(&self, mask: FilterMask) -> VortexResult<ArrayData> {
let validity = self.validity().filter(&mask)?;
let primitive_run_ends = self.ends().into_primitive()?;
let (run_ends, mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| {
filter_run_ends(primitive_run_ends.maybe_null_slice::<$P>(), mask)?
let (run_ends, values_mask) = match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |$P| {
filter_run_ends(primitive_run_ends.maybe_null_slice::<$P>(), self.offset() as u64, self.len() as u64, mask)?
});
let validity = self.validity().filter(&mask)?;
let values = filter(&self.values(), mask)?;
let values = filter(&self.values(), values_mask)?;

RunEndArray::try_new(run_ends.into_array(), values, validity).map(|a| a.into_array())
}
}
Expand All @@ -160,6 +160,8 @@ impl FilterFn for RunEndArray {
// Code adapted from apache arrow-rs https://github.com/apache/arrow-rs/blob/b1f5c250ebb6c1252b4e7c51d15b8e77f4c361fa/arrow-select/src/filter.rs#L425
fn filter_run_ends<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(
run_ends: &[R],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried converting this to `impl Iterator<Item = usize>`, but I would still need to pass the length, and it would needlessly widen the type of the ends.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can do impl Iterator + ExactSizeIterator

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oooo, let me try that

offset: u64,
length: u64,
mask: FilterMask,
) -> VortexResult<(PrimitiveArray, FilterMask)> {
let mut new_run_ends = vec![R::zero(); run_ends.len()];
Expand All @@ -171,7 +173,7 @@ fn filter_run_ends<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(

let new_mask: FilterMask = BooleanBuffer::collect_bool(run_ends.len(), |i| {
let mut keep = false;
let end = run_ends[i].as_();
let end = min(run_ends[i].as_() - offset, length);

// Safety: predicate must be the same length as the array the ends have been taken from
for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
Expand Down Expand Up @@ -464,6 +466,34 @@ mod test {
);
}

// Regression test: filter a run-end array that has first been sliced to the range 2..7, then
// check the ends and values of the filtered result.
#[test]
fn filter_sliced_run_end() {
    let sliced = slice(ree_array(), 2, 7).unwrap();
    let keep = FilterMask::from_iter([true, false, false, true, true]);
    let filtered_run_end = RunEndArray::try_from(filter(&sliced, keep).unwrap()).unwrap();

    let ends = filtered_run_end.ends().into_primitive().unwrap();
    assert_eq!(ends.maybe_null_slice::<u64>(), [1, 2, 3]);

    let values = filtered_run_end.values().into_primitive().unwrap();
    assert_eq!(values.maybe_null_slice::<i32>(), [1, 4, 2]);
}

#[test]
fn compare_run_end() {
let arr = ree_array();
Expand Down
33 changes: 33 additions & 0 deletions encodings/runend/src/iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use std::cmp::min;

use num_traits::{AsPrimitive, FromPrimitive};
use vortex_dtype::NativePType;
use vortex_error::vortex_panic;

/// Iterates over run ends rebased onto a slice window and widened to `usize`.
///
/// Every end has `offset` subtracted from it, is capped at `length`, and is finally converted
/// with `as_()`. The subtraction and the cap are carried out in the native end type `E`.
///
/// # Panics
///
/// Panics through `vortex_panic!` when `offset` or `length` cannot be converted to `E`.
#[inline]
pub fn trimmed_ends_iter<E: NativePType + FromPrimitive + AsPrimitive<usize> + Ord>(
    run_ends: &[E],
    offset: usize,
    length: usize,
) -> impl Iterator<Item = usize> + use<'_, E> {
    let start = match E::from_usize(offset) {
        Some(converted) => converted,
        None => vortex_panic!(
            "offset {} cannot be converted to {}",
            offset,
            std::any::type_name::<E>()
        ),
    };
    let limit = match E::from_usize(length) {
        Some(converted) => converted,
        None => vortex_panic!(
            "length {} cannot be converted to {}",
            length,
            std::any::type_name::<E>()
        ),
    };
    // Shift, clamp and widen each end in a single pass.
    run_ends
        .iter()
        .map(move |&end| min(end - start, limit).as_())
}
1 change: 1 addition & 0 deletions encodings/runend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pub use array::*;
mod array;
pub mod compress;
mod compute;
mod iter;
10 changes: 9 additions & 1 deletion fuzz/fuzz_targets/array_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,15 @@ fn assert_search_sorted(
}

fn assert_array_eq(lhs: &ArrayData, rhs: &ArrayData, step: usize) {
assert_eq!(lhs.len(), rhs.len());
assert_eq!(
lhs.len(),
rhs.len(),
"LHS len {} != RHS len {}, lhs is {} rhs is {} in step {step}",
lhs.len(),
rhs.len(),
lhs.encoding().id(),
rhs.encoding().id()
);
for idx in 0..lhs.len() {
let l = scalar_at(lhs, idx).unwrap();
let r = scalar_at(rhs, idx).unwrap();
Expand Down
Loading