Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 4 additions & 23 deletions encodings/alp/src/alp/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@ use vortex_array::ArrayRef;
use vortex_array::Canonical;
use vortex_array::DeserializeMetadata;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::Precision;
use vortex_array::ProstMetadata;
use vortex_array::SerializeMetadata;
use vortex_array::arrays::SliceVTable;
use vortex_array::buffer::BufferHandle;
use vortex_array::patches::Patches;
use vortex_array::patches::PatchesMetadata;
Expand All @@ -43,6 +41,7 @@ use vortex_error::vortex_err;
use crate::ALPFloat;
use crate::alp::Exponents;
use crate::alp::decompress::execute_decompress;
use crate::alp::rules::PARENT_KERNELS;

vtable!(ALP);

Expand Down Expand Up @@ -177,28 +176,10 @@ impl VTable for ALPVTable {
fn execute_parent(
array: &Self::Array,
parent: &ArrayRef,
_child_idx: usize,
_ctx: &mut ExecutionCtx,
child_idx: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
// CPU-only: if parent is SliceArray, perform slicing of the buffer and any patches
// Note that this triggers compute (binary searching Patches) which we cannot do when the
// buffers live in GPU memory.
if let Some(slice_array) = parent.as_opt::<SliceVTable>() {
let range = slice_array.slice_range().clone();
let sliced_alp = ALPArray::new(
array.encoded().slice(range.clone())?,
array.exponents(),
array
.patches()
.map(|p| p.slice(range))
.transpose()?
.flatten(),
)
.into_array();
return Ok(Some(sliced_alp));
}

Ok(None)
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
}
}

Expand Down
29 changes: 16 additions & 13 deletions encodings/alp/src/alp/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::ArrayRef;
use vortex_array::compute::FilterKernel;
use vortex_array::compute::FilterKernelAdapter;
use vortex_array::register_kernel;
use vortex_array::ExecutionCtx;
use vortex_array::arrays::FilterKernel;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::ALPArray;
use crate::ALPVTable;

impl FilterKernel for ALPVTable {
fn filter(&self, array: &ALPArray, mask: &Mask) -> VortexResult<ArrayRef> {
fn filter(
array: &ALPArray,
mask: &Mask,
_ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let patches = array
.patches()
.map(|p| p.filter(mask))
Expand All @@ -21,19 +24,19 @@ impl FilterKernel for ALPVTable {

// SAFETY: filtering the values does not change correctness
unsafe {
Ok(ALPArray::new_unchecked(
array.encoded().filter(mask.clone())?,
array.exponents(),
patches,
array.dtype().clone(),
)
.to_array())
Ok(Some(
ALPArray::new_unchecked(
array.encoded().filter(mask.clone())?,
array.exponents(),
patches,
array.dtype().clone(),
)
.to_array(),
))
}
}
}

register_kernel!(FilterKernelAdapter(ALPVTable).lift());

#[cfg(test)]
mod test {
use rstest::rstest;
Expand Down
1 change: 1 addition & 0 deletions encodings/alp/src/alp/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod compare;
mod filter;
mod mask;
mod nan_count;
mod slice;
mod take;

#[cfg(test)]
Expand Down
33 changes: 33 additions & 0 deletions encodings/alp/src/alp/compute/slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::ops::Range;

use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::SliceKernel;
use vortex_error::VortexResult;

use crate::ALPArray;
use crate::ALPVTable;

impl SliceKernel for ALPVTable {
fn slice(
array: &Self::Array,
range: Range<usize>,
_ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let sliced_alp = ALPArray::new(
array.encoded().slice(range.clone())?,
array.exponents(),
array
.patches()
.map(|p| p.slice(range))
.transpose()?
.flatten(),
)
.into_array();
Ok(Some(sliced_alp))
}
}
1 change: 1 addition & 0 deletions encodings/alp/src/alp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod compress;
mod compute;
mod decompress;
mod ops;
mod rules;

#[cfg(test)]
mod tests {
Expand Down
13 changes: 13 additions & 0 deletions encodings/alp/src/alp/rules.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::arrays::FilterExecuteAdaptor;
use vortex_array::arrays::SliceExecuteAdaptor;
use vortex_array::kernel::ParentKernelSet;

use crate::ALPVTable;

pub(super) const PARENT_KERNELS: ParentKernelSet<ALPVTable> = ParentKernelSet::new(&[
ParentKernelSet::lift(&FilterExecuteAdaptor(ALPVTable)),
ParentKernelSet::lift(&SliceExecuteAdaptor(ALPVTable)),
]);
33 changes: 18 additions & 15 deletions encodings/alp/src/alp_rd/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,41 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::compute::FilterKernel;
use vortex_array::compute::FilterKernelAdapter;
use vortex_array::register_kernel;
use vortex_array::arrays::FilterKernel;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::ALPRDArray;
use crate::ALPRDVTable;

impl FilterKernel for ALPRDVTable {
fn filter(&self, array: &ALPRDArray, mask: &Mask) -> VortexResult<ArrayRef> {
fn filter(
array: &ALPRDArray,
mask: &Mask,
_ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let left_parts_exceptions = array
.left_parts_patches()
.map(|patches| patches.filter(mask))
.transpose()?
.flatten();

Ok(ALPRDArray::try_new(
array.dtype().clone(),
array.left_parts().filter(mask.clone())?,
array.left_parts_dictionary().clone(),
array.right_parts().filter(mask.clone())?,
array.right_bit_width(),
left_parts_exceptions,
)?
.into_array())
Ok(Some(
ALPRDArray::try_new(
array.dtype().clone(),
array.left_parts().filter(mask.clone())?,
array.left_parts_dictionary().clone(),
array.right_parts().filter(mask.clone())?,
array.right_bit_width(),
left_parts_exceptions,
)?
.into_array(),
))
}
}

register_kernel!(FilterKernelAdapter(ALPRDVTable).lift());

#[cfg(test)]
mod test {
use rstest::rstest;
Expand Down
7 changes: 5 additions & 2 deletions encodings/alp/src/alp_rd/kernel.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::arrays::FilterExecuteAdaptor;
use vortex_array::arrays::SliceExecuteAdaptor;
use vortex_array::kernel::ParentKernelSet;

use crate::alp_rd::ALPRDVTable;

pub(crate) static PARENT_KERNELS: ParentKernelSet<ALPRDVTable> =
ParentKernelSet::new(&[ParentKernelSet::lift(&SliceExecuteAdaptor(ALPRDVTable))]);
pub(crate) static PARENT_KERNELS: ParentKernelSet<ALPRDVTable> = ParentKernelSet::new(&[
ParentKernelSet::lift(&SliceExecuteAdaptor(ALPRDVTable)),
ParentKernelSet::lift(&FilterExecuteAdaptor(ALPRDVTable)),
]);
25 changes: 12 additions & 13 deletions encodings/datetime-parts/src/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,26 @@

use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::compute::FilterKernel;
use vortex_array::compute::FilterKernelAdapter;
use vortex_array::register_kernel;
use vortex_array::arrays::FilterReduce;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::DateTimePartsArray;
use crate::DateTimePartsVTable;

impl FilterKernel for DateTimePartsVTable {
fn filter(&self, array: &DateTimePartsArray, mask: &Mask) -> VortexResult<ArrayRef> {
Ok(DateTimePartsArray::try_new(
array.dtype().clone(),
array.days().filter(mask.clone())?,
array.seconds().filter(mask.clone())?,
array.subseconds().filter(mask.clone())?,
)?
.into_array())
impl FilterReduce for DateTimePartsVTable {
fn filter(array: &DateTimePartsArray, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
Ok(Some(
DateTimePartsArray::try_new(
array.dtype().clone(),
array.days().filter(mask.clone())?,
array.seconds().filter(mask.clone())?,
array.subseconds().filter(mask.clone())?,
)?
.into_array(),
))
}
}
register_kernel!(FilterKernelAdapter(DateTimePartsVTable).lift());

#[cfg(test)]
mod test {
Expand Down
2 changes: 2 additions & 0 deletions encodings/datetime-parts/src/compute/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use vortex_array::arrays::AnyScalarFn;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::ConstantVTable;
use vortex_array::arrays::FilterArray;
use vortex_array::arrays::FilterReduceAdaptor;
use vortex_array::arrays::FilterVTable;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::SliceReduceAdaptor;
Expand All @@ -29,6 +30,7 @@ use crate::timestamp;
pub(crate) const PARENT_RULES: ParentRuleSet<DateTimePartsVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&DTPFilterPushDownRule),
ParentRuleSet::lift(&DTPComparisonPushDownRule),
ParentRuleSet::lift(&FilterReduceAdaptor(DateTimePartsVTable)),
ParentRuleSet::lift(&SliceReduceAdaptor(DateTimePartsVTable)),
]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,20 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::ArrayRef;
use vortex_array::compute::FilterKernel;
use vortex_array::compute::FilterKernelAdapter;
use vortex_array::register_kernel;
use vortex_array::arrays::FilterReduce;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::DecimalBytePartsArray;
use crate::DecimalBytePartsVTable;

impl FilterKernel for DecimalBytePartsVTable {
fn filter(&self, array: &Self::Array, mask: &Mask) -> VortexResult<ArrayRef> {
impl FilterReduce for DecimalBytePartsVTable {
fn filter(array: &DecimalBytePartsArray, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
DecimalBytePartsArray::try_new(array.msp.filter(mask.clone())?, *array.decimal_dtype())
.map(|d| d.to_array())
.map(|d| Some(d.to_array()))
}
}

register_kernel!(FilterKernelAdapter(DecimalBytePartsVTable).lift());

#[cfg(test)]
mod test {
use vortex_array::IntoArray;
Expand Down
2 changes: 2 additions & 0 deletions encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use vortex_array::Array;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::arrays::FilterArray;
use vortex_array::arrays::FilterReduceAdaptor;
use vortex_array::arrays::FilterVTable;
use vortex_array::arrays::SliceReduceAdaptor;
use vortex_array::optimizer::rules::ArrayParentReduceRule;
Expand All @@ -16,6 +17,7 @@ use crate::DecimalBytePartsVTable;

pub(super) const PARENT_RULES: ParentRuleSet<DecimalBytePartsVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&DecimalBytePartsFilterPushDownRule),
ParentRuleSet::lift(&FilterReduceAdaptor(DecimalBytePartsVTable)),
ParentRuleSet::lift(&SliceReduceAdaptor(DecimalBytePartsVTable)),
]);

Expand Down
Loading
Loading