Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions vortex-compute/src/filter/bitbuffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
use vortex_buffer::{BitBuffer, BitBufferMut, get_bit};
use vortex_mask::{Mask, MaskIter};

use crate::filter::Filter;
use crate::filter::{Filter, MaskIndices};

/// If the filter density is above 80%, we use slices to filter the array instead of indices.
// TODO(ngates): we need more experimentation to determine the best threshold here.
const FILTER_SLICES_DENSITY_THRESHOLD: f64 = 0.8;

impl Filter for &BitBuffer {
impl Filter<Mask> for &BitBuffer {
type Output = BitBuffer;

fn filter(self, selection_mask: &Mask) -> BitBuffer {
Expand All @@ -33,6 +33,14 @@ impl Filter for &BitBuffer {
}
}

impl Filter<MaskIndices<'_>> for &BitBuffer {
type Output = BitBuffer;

fn filter(self, indices: &MaskIndices) -> BitBuffer {
filter_indices(self, indices)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here too

}
}

fn filter_indices(bools: &BitBuffer, indices: &[usize]) -> BitBuffer {
let buffer = bools.inner().as_ref();
BitBuffer::collect_bool(indices.len(), |idx| {
Expand Down
40 changes: 29 additions & 11 deletions vortex-compute/src/filter/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
use vortex_buffer::{Buffer, BufferMut};
use vortex_mask::{Mask, MaskIter};

use crate::filter::Filter;
use crate::filter::{Filter, MaskIndices};

// This is modeled after the constant with the equivalent name in arrow-rs.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;

impl<T: Copy> Filter for &Buffer<T> {
impl<T: Copy> Filter<Mask> for &Buffer<T> {
type Output = Buffer<T>;

fn filter(self, selection_mask: &Mask) -> Buffer<T> {
Expand All @@ -32,7 +32,15 @@ impl<T: Copy> Filter for &Buffer<T> {
}
}

impl<T: Copy> Filter for &mut BufferMut<T> {
impl<T: Copy> Filter<MaskIndices<'_>> for &Buffer<T> {
type Output = Buffer<T>;

fn filter(self, indices: &MaskIndices) -> Buffer<T> {
filter_indices(self, indices)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert that the lengths are the same?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the lengths won't be the same b/c it's an indices

}
}

impl<T: Copy> Filter<Mask> for &mut BufferMut<T> {
type Output = ();

fn filter(self, selection_mask: &Mask) {
Expand Down Expand Up @@ -69,16 +77,26 @@ impl<T: Copy> Filter for &mut BufferMut<T> {
}
}

impl<T: Copy> Filter for Buffer<T> {
type Output = Self;
impl<T: Copy> Filter<MaskIndices<'_>> for &mut BufferMut<T> {
type Output = ();

fn filter(self, selection_mask: &Mask) -> Self {
assert_eq!(
selection_mask.len(),
self.len(),
"Selection mask length must equal the buffer length"
);
fn filter(self, indices: &MaskIndices) -> Self::Output {
for (write_index, &read_index) in indices.iter().enumerate() {
self[write_index] = self[read_index];
}

self.truncate(indices.len());
}
}

impl<M, T: Copy> Filter<M> for Buffer<T>
where
for<'a> &'a Buffer<T>: Filter<M, Output = Buffer<T>>,
for<'a> &'a mut BufferMut<T>: Filter<M, Output = ()>,
{
type Output = Self;

fn filter(self, selection_mask: &M) -> Self {
// If we have exclusive access, we can perform the filter in place.
match self.try_into_mut() {
Ok(mut buffer_mut) => {
Expand Down
44 changes: 41 additions & 3 deletions vortex-compute/src/filter/mask.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_mask::Mask;
use vortex_mask::{Mask, MaskMut};

use crate::filter::Filter;
use crate::filter::{Filter, MaskIndices};

impl Filter for &Mask {
impl Filter<Mask> for &Mask {
type Output = Mask;

fn filter(self, selection_mask: &Mask) -> Mask {
Expand All @@ -27,3 +27,41 @@ impl Filter for &Mask {
}
}
}

impl Filter<MaskIndices<'_>> for &Mask {
    type Output = Mask;

    /// Selects the mask entries at the given positions, producing a mask with one
    /// entry per index.
    ///
    /// Constant masks stay constant and take the length of `indices`; a mask backed
    /// by values delegates to the index filter of its bit buffer.
    fn filter(self, indices: &MaskIndices<'_>) -> Mask {
        let selected_len = indices.len();
        match self {
            Mask::AllTrue(_) => Mask::AllTrue(selected_len),
            Mask::AllFalse(_) => Mask::AllFalse(selected_len),
            Mask::Values(values) => {
                let bits = values.bit_buffer().filter(indices);
                Mask::from(bits)
            }
        }
    }
}

impl Filter<Mask> for &mut MaskMut {
    type Output = ();

    /// Replaces this mutable mask with the result of filtering it by `selection_mask`.
    ///
    /// # Panics
    ///
    /// Panics if `selection_mask` and this mask have different lengths.
    fn filter(self, selection_mask: &Mask) {
        assert_eq!(
            selection_mask.len(),
            self.len(),
            "Selection mask length must equal the mask length"
        );

        // TODO(connor): There is definitely a better way to do this (in place).
        // For now: clone, freeze into an immutable mask, filter that, and convert back.
        let frozen = self.clone().freeze();
        *self = frozen.filter(selection_mask).into_mut();
    }
}

impl Filter<MaskIndices<'_>> for &mut MaskMut {
    type Output = ();

    /// Replaces this mutable mask with the entries selected by `indices`.
    ///
    /// No length assertion is made here: `indices` holds positions rather than one
    /// flag per element, so its length is the output length.
    fn filter(self, indices: &MaskIndices<'_>) -> Self::Output {
        // TODO(aduffy): Filter in-place
        // For now: clone, freeze into an immutable mask, filter that, and convert back.
        let frozen = self.clone().freeze();
        *self = frozen.filter(indices).into_mut();
    }
}
40 changes: 36 additions & 4 deletions vortex-compute/src/filter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

//! Filter function.

use std::ops::Deref;

mod bitbuffer;
mod buffer;
mod mask;
mod vector;

use vortex_mask::Mask;

/// Function for filtering based on a selection mask.
pub trait Filter {
pub trait Filter<By: ?Sized> {
/// The result type after performing the operation.
type Output;

Expand All @@ -22,5 +22,37 @@ pub trait Filter {
/// # Panics
///
/// If the length of the mask does not equal the length of the value being filtered.
fn filter(self, selection_mask: &Mask) -> Self::Output;
fn filter(self, selection: &By) -> Self::Output;
}

/// A borrowed view over a set of strictly sorted indices from a bit mask.
///
/// Unlike an arbitrary index list, `MaskIndices` is always strict-sorted: every index
/// is unique and greater than the one before it.
///
/// `MaskIndices` dereferences to `[usize]`, so you can iterate over it or index into
/// it just like you would a slice.
///
/// The view only wraps a shared slice reference, so copying it never copies the
/// underlying indices.
#[derive(Debug, Clone, Copy)]
pub struct MaskIndices<'a>(&'a [usize]);

impl<'a> MaskIndices<'a> {
/// Create new indices from a slice of strict-sorted index values.
///
/// # Safety
///
/// The caller must ensure that the indices are strict-sorted, i.e. that they
/// are monotonic and unique.
///
/// Users of the `MaskIndices` type assume this and failure to uphold this guarantee
/// can result in UB downstream.
pub unsafe fn new_unchecked(indices: &'a [usize]) -> Self {
Self(indices)
Comment on lines +47 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add a debug_assert!(indices.is_sorted())

}
}

/// Lets a `MaskIndices` be used wherever a `&[usize]` is expected.
impl Deref for MaskIndices<'_> {
    type Target = [usize];

    /// Returns the wrapped slice of indices.
    fn deref(&self) -> &[usize] {
        self.0
    }
}
Loading
Loading