Skip to content

Commit b3eec28

Browse files
committed
fix: avoid oob take for primitive
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent fe4c81b commit b3eec28

File tree

2 files changed

+118
-33
lines changed

2 files changed

+118
-33
lines changed

vortex-array/src/arrays/primitive/compute/take/mod.rs

Lines changed: 2 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -6,21 +6,16 @@ mod avx2;
66

77
#[cfg(vortex_nightly)]
88
mod portable;
9+
mod scalar;
910

1011
use std::sync::LazyLock;
1112

12-
use vortex_buffer::Buffer;
1313
use vortex_dtype::DType;
14-
use vortex_dtype::IntegerPType;
15-
use vortex_dtype::NativePType;
16-
use vortex_dtype::match_each_integer_ptype;
17-
use vortex_dtype::match_each_native_ptype;
1814
use vortex_error::VortexResult;
1915
use vortex_error::vortex_bail;
2016

2117
use crate::Array;
2218
use crate::ArrayRef;
23-
use crate::IntoArray;
2419
use crate::ToCanonical;
2520
use crate::arrays::PrimitiveVTable;
2621
use crate::arrays::primitive::PrimitiveArray;
@@ -48,7 +43,7 @@ static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(||
4843
}
4944
} else {
5045
// stable all other platforms: scalar kernel
51-
&TakeKernelScalar
46+
&scalar::TakeKernelScalar
5247
}
5348
}
5449
});
@@ -62,25 +57,6 @@ trait TakeImpl: Send + Sync {
6257
) -> VortexResult<ArrayRef>;
6358
}
6459

65-
#[allow(unused)]
66-
struct TakeKernelScalar;
67-
68-
impl TakeImpl for TakeKernelScalar {
69-
fn take(
70-
&self,
71-
array: &PrimitiveArray,
72-
indices: &PrimitiveArray,
73-
validity: Validity,
74-
) -> VortexResult<ArrayRef> {
75-
match_each_native_ptype!(array.ptype(), |T| {
76-
match_each_integer_ptype!(indices.ptype(), |I| {
77-
let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
78-
Ok(PrimitiveArray::new(values, validity).into_array())
79-
})
80-
})
81-
}
82-
}
83-
8460
impl TakeKernel for PrimitiveVTable {
8561
fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
8662
let DType::Primitive(ptype, null) = indices.dtype() else {
@@ -102,13 +78,6 @@ impl TakeKernel for PrimitiveVTable {
10278

10379
register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
10480

105-
// Compiler may see this as unused based on enabled features
106-
#[allow(unused)]
107-
#[inline(always)]
108-
fn take_primitive_scalar<T: NativePType, I: IntegerPType>(array: &[T], indices: &[I]) -> Buffer<T> {
109-
indices.iter().map(|idx| array[idx.as_()]).collect()
110-
}
111-
11281
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
11382
#[cfg(test)]
11483
mod test {
vortex-array/src/arrays/primitive/compute/take/scalar.rs

Lines changed: 116 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,116 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_buffer::Buffer;
5+
use vortex_dtype::IntegerPType;
6+
use vortex_dtype::NativePType;
7+
use vortex_dtype::match_each_integer_ptype;
8+
use vortex_dtype::match_each_native_ptype;
9+
use vortex_error::VortexResult;
10+
11+
use crate::ArrayRef;
12+
use crate::IntoArray;
13+
use crate::arrays::PrimitiveArray;
14+
use crate::arrays::primitive::compute::take::TakeImpl;
15+
use crate::validity::Validity;
16+
use crate::vtable::ValidityHelper;
17+
18+
/// Scalar take kernel for primitive arrays: a field-less marker type whose
/// `TakeImpl` implementation performs the gather with plain slice indexing.
// NOTE(review): `allow(unused)` is presumably needed because builds that select
// a different kernel never reference this type — confirm.
#[allow(unused)]
19+
pub(super) struct TakeKernelScalar;
20+
21+
impl TakeImpl for TakeKernelScalar {
22+
/// Gathers `array[indices[i]]` for every position `i` and returns the result
/// as a new `PrimitiveArray` converted with `into_array()`.
///
/// * `array` - the values to gather from.
/// * `indices` - the positions to read; its own validity decides whether the
///   fast or the null-aware path is used.
/// * `validity` - attached to the output unchanged; it is not consulted when
///   reading values.
///
/// # Panics
///
/// Both helper paths index `array` directly, so an index at a valid position
/// that is out of range panics.
// NOTE(review): the nested dispatch macros presumably expand this body once per
// (value type, index type) pair, which is likely why the complexity lint is
// silenced — confirm.
#[allow(clippy::cognitive_complexity)]
23+
fn take(
24+
&self,
25+
array: &PrimitiveArray,
26+
indices: &PrimitiveArray,
27+
validity: Validity,
28+
) -> VortexResult<ArrayRef> {
29+
match_each_native_ptype!(array.ptype(), |T| {
30+
match_each_integer_ptype!(indices.ptype(), |I| {
31+
let indices_slice = indices.as_slice::<I>();
32+
let indices_validity = indices.validity();
33+
let values = if indices_validity.all_valid(indices_slice.len()) {
34+
// Fast path: the indices contain no nulls, so every entry is a real
// index and can be used to read `array` directly.
35+
take_primitive_scalar(array.as_slice::<T>(), indices_slice)
36+
} else {
37+
// Slow path: null index slots may hold arbitrary (possibly
// out-of-bounds) values, so they must be skipped instead of read.
38+
take_primitive_scalar_with_nulls(
39+
array.as_slice::<T>(),
40+
indices_slice,
41+
indices_validity,
42+
)
43+
};
44+
// The caller-supplied validity is what marks positions of the output as null.
Ok(PrimitiveArray::new(values, validity).into_array())
45+
})
46+
})
47+
}
48+
}
49+
50+
/// Builds a buffer containing `array[idx]` for each `idx` in `indices`, in order.
///
/// Every entry of `indices` is used to index `array` (after conversion with
/// `as_()`), so an entry that is out of range for `array` panics. Callers must
/// only pass indices that are all in bounds.
// Compiler may see this as unused based on enabled features
51+
#[allow(unused)]
52+
#[inline(always)]
53+
fn take_primitive_scalar<T: NativePType, I: IntegerPType>(array: &[T], indices: &[I]) -> Buffer<T> {
54+
indices.iter().map(|idx| array[idx.as_()]).collect()
55+
}
56+
57+
/// Slow path for take when indices may contain nulls with garbage values.
58+
/// Null positions are never used to index `array`; they are filled with
/// `T::zero()` as a placeholder value instead.
///
/// * `array` - the values to gather from.
/// * `indices` - raw index values; only entries at valid positions are read.
/// * `validity` - validity of `indices`, queried once per position.
///
/// Returns a buffer with one element per entry of `indices`. An index at a
/// valid position that is out of range for `array` still panics.
// NOTE(review): the placeholder zeros are presumably hidden by the validity the
// caller attaches to the resulting array — confirm against callers.
59+
#[allow(unused)]
60+
#[inline(always)]
61+
fn take_primitive_scalar_with_nulls<T: NativePType, I: IntegerPType>(
62+
array: &[T],
63+
indices: &[I],
64+
validity: &Validity,
65+
) -> Buffer<T> {
66+
indices
67+
.iter()
68+
.enumerate()
69+
.map(|(i, idx)| {
70+
// Only dereference the index when this position is valid.
if validity.is_valid(i) {
71+
array[idx.as_()]
72+
} else {
73+
// Null position: emit a placeholder without touching `array`.
T::zero()
74+
}
75+
})
76+
.collect()
77+
}
78+
79+
// Unit tests that call the scalar kernel directly through `TakeImpl::take`.
#[cfg(test)]
80+
mod tests {
81+
use vortex_buffer::buffer;
82+
83+
use crate::IntoArray;
84+
use crate::ToCanonical;
85+
use crate::arrays::PrimitiveArray;
86+
use crate::arrays::primitive::compute::take::TakeImpl;
87+
use crate::arrays::primitive::compute::take::scalar::TakeKernelScalar;
88+
use crate::validity::Validity;
89+
90+
// Non-nullable, in-bounds indices (with repeats) gather the matching values in order.
#[test]
91+
fn test_scalar_basic() {
92+
let values = buffer![1, 2, 3, 4, 5].into_array().to_primitive();
93+
let indices = buffer![0, 1, 1, 2, 2, 3, 4].into_array().to_primitive();
94+
95+
let result = TakeKernelScalar
96+
.take(&values, &indices, Validity::NonNullable)
97+
.unwrap()
98+
.to_primitive();
99+
assert_eq!(result.as_slice::<i32>(), &[1, 2, 2, 3, 3, 4, 5]);
100+
}
101+
102+
// A null index slot holding an out-of-bounds value must not be read.
#[test]
103+
fn test_scalar_with_nulls() {
104+
let values = buffer![1, 2, 3, 4, 5].into_array().to_primitive();
105+
let validity = Validity::from_iter([true, false, true, true, true]);
106+
// Position 1 is null and carries index 100, which is out of range for the
// five-element `values`.
let indices = PrimitiveArray::new(buffer![0, 100, 2, 3, 4], validity.clone());
107+
108+
let result = TakeKernelScalar
109+
.take(&values, &indices, validity.clone())
110+
.unwrap()
111+
.to_primitive();
112+
113+
// The null position yields the zero placeholder, and the validity passed to
// `take` comes back unchanged on the result.
assert_eq!(result.as_slice::<i32>(), &[1, 0, 3, 4, 5]);
114+
assert_eq!(result.validity, validity);
115+
}
116+
}

0 commit comments

Comments
 (0)