Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/primitive/compute/take/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::primitive::PrimitiveArray;
use crate::arrays::primitive::compute::take::TakeImpl;
use crate::arrays::primitive::compute::take::take_primitive_scalar;
use crate::arrays::primitive::compute::take::scalar::take_primitive_scalar;
use crate::validity::Validity;

#[allow(unused)]
Expand Down
37 changes: 3 additions & 34 deletions vortex-array/src/arrays/primitive/compute/take/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,16 @@ mod avx2;

#[cfg(vortex_nightly)]
mod portable;
mod scalar;

use std::sync::LazyLock;

use vortex_buffer::Buffer;
use vortex_dtype::DType;
use vortex_dtype::IntegerPType;
use vortex_dtype::NativePType;
use vortex_dtype::match_each_integer_ptype;
use vortex_dtype::match_each_native_ptype;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;

use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::PrimitiveVTable;
use crate::arrays::primitive::PrimitiveArray;
Expand All @@ -44,11 +39,11 @@ static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(||
if is_x86_feature_detected!("avx2") {
&avx2::TakeKernelAVX2
} else {
&TakeKernelScalar
&scalar::TakeKernelScalar
}
} else {
// stable all other platforms: scalar kernel
&TakeKernelScalar
&scalar::TakeKernelScalar
}
}
});
Expand All @@ -62,25 +57,6 @@ trait TakeImpl: Send + Sync {
) -> VortexResult<ArrayRef>;
}

#[allow(unused)]
struct TakeKernelScalar;

impl TakeImpl for TakeKernelScalar {
fn take(
&self,
array: &PrimitiveArray,
indices: &PrimitiveArray,
validity: Validity,
) -> VortexResult<ArrayRef> {
match_each_native_ptype!(array.ptype(), |T| {
match_each_integer_ptype!(indices.ptype(), |I| {
let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
Ok(PrimitiveArray::new(values, validity).into_array())
})
})
}
}

impl TakeKernel for PrimitiveVTable {
fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
let DType::Primitive(ptype, null) = indices.dtype() else {
Expand All @@ -102,13 +78,6 @@ impl TakeKernel for PrimitiveVTable {

register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());

// Compiler may see this as unused based on enabled features
#[allow(unused)]
#[inline(always)]
fn take_primitive_scalar<T: NativePType, I: IntegerPType>(array: &[T], indices: &[I]) -> Buffer<T> {
indices.iter().map(|idx| array[idx.as_()]).collect()
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[cfg(test)]
mod test {
Expand Down
119 changes: 119 additions & 0 deletions vortex-array/src/arrays/primitive/compute/take/scalar.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_buffer::Buffer;
use vortex_dtype::IntegerPType;
use vortex_dtype::NativePType;
use vortex_dtype::match_each_integer_ptype;
use vortex_dtype::match_each_native_ptype;
use vortex_error::VortexResult;

use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::primitive::compute::take::TakeImpl;
use crate::validity::Validity;
use crate::vtable::ValidityHelper;

#[allow(unused)]
pub(super) struct TakeKernelScalar;

impl TakeImpl for TakeKernelScalar {
#[allow(clippy::cognitive_complexity)]
fn take(
&self,
array: &PrimitiveArray,
indices: &PrimitiveArray,
validity: Validity,
) -> VortexResult<ArrayRef> {
match_each_native_ptype!(array.ptype(), |T| {
match_each_integer_ptype!(indices.ptype(), |I| {
let indices_slice = indices.as_slice::<I>();
let indices_validity = indices.validity();
let values = if indices_validity.all_valid(indices_slice.len()) {
// Fast path: indices have no nulls, safe to index directly
take_primitive_scalar(array.as_slice::<T>(), indices_slice)
} else {
// Slow path: indices may have nulls with garbage values
take_primitive_scalar_with_nulls(
array.as_slice::<T>(),
indices_slice,
indices_validity,
)
};
Ok(PrimitiveArray::new(values, validity).into_array())
})
})
}
}

// Compiler may see this as unused based on enabled features
#[allow(unused)]
#[inline(always)]
pub(super) fn take_primitive_scalar<T: NativePType, I: IntegerPType>(
array: &[T],
indices: &[I],
) -> Buffer<T> {
indices.iter().map(|idx| array[idx.as_()]).collect()
}

/// Slow path for take when indices may contain nulls with garbage values.
/// Uses 0 as a safe index for null positions (the value will be masked out by validity).
#[allow(unused)]
#[inline(always)]
fn take_primitive_scalar_with_nulls<T: NativePType, I: IntegerPType>(
array: &[T],
indices: &[I],
validity: &Validity,
) -> Buffer<T> {
indices
.iter()
.enumerate()
.map(|(i, idx)| {
if validity.is_valid(i) {
array[idx.as_()]
} else {
T::zero()
}
})
.collect()
}

#[cfg(test)]
mod tests {
use vortex_buffer::buffer;

use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::PrimitiveArray;
use crate::arrays::primitive::compute::take::TakeImpl;
use crate::arrays::primitive::compute::take::scalar::TakeKernelScalar;
use crate::validity::Validity;

#[test]
fn test_scalar_basic() {
let values = buffer![1, 2, 3, 4, 5].into_array().to_primitive();
let indices = buffer![0, 1, 1, 2, 2, 3, 4].into_array().to_primitive();

let result = TakeKernelScalar
.take(&values, &indices, Validity::NonNullable)
.unwrap()
.to_primitive();
assert_eq!(result.as_slice::<i32>(), &[1, 2, 2, 3, 3, 4, 5]);
}

#[test]
fn test_scalar_with_nulls() {
let values = buffer![1, 2, 3, 4, 5].into_array().to_primitive();
let validity = Validity::from_iter([true, false, true, true, true]);
let indices = PrimitiveArray::new(buffer![0, 100, 2, 3, 4], validity.clone());

let result = TakeKernelScalar
.take(&values, &indices, validity.clone())
.unwrap()
.to_primitive();

assert_eq!(result.as_slice::<i32>(), &[1, 0, 3, 4, 5]);
assert_eq!(result.validity, validity);
}
}
3 changes: 2 additions & 1 deletion vortex-array/src/patches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,8 @@ mod test {
let primitive_values = taken.values().to_primitive();
let primitive_indices = taken.indices().to_primitive();
assert_eq!(taken.array_len(), 2);
assert_eq!(primitive_values.as_slice::<i32>(), [44, 33]);
assert_eq!(primitive_values.scalar_at(0), Some(44i32).into());
assert_eq!(primitive_values.scalar_at(1), Option::<i32>::None.into());
assert_eq!(primitive_indices.as_slice::<u64>(), [0, 1]);

assert_eq!(
Expand Down
Loading