vortex-array/src/arrays/varbin/compute/take.rs (99 additions, 29 deletions)

@@ -2,16 +2,19 @@
 // SPDX-FileCopyrightText: Copyright the Vortex contributors

 use std::iter::Sum;
+use std::ops::AddAssign;

+use arrow_buffer::BooleanBufferBuilder;
 use num_traits::PrimInt;
+use vortex_buffer::{BufferMut, ByteBufferMut};
 use vortex_dtype::{DType, NativePType, match_each_integer_ptype};
-use vortex_error::{VortexResult, vortex_err, vortex_panic};
+use vortex_error::{VortexExpect, VortexResult, vortex_panic};
 use vortex_mask::Mask;

-use crate::arrays::VarBinVTable;
 use crate::arrays::varbin::VarBinArray;
-use crate::arrays::varbin::builder::VarBinBuilder;
+use crate::arrays::{PrimitiveArray, VarBinVTable};
 use crate::compute::{TakeKernel, TakeKernelAdapter};
+use crate::validity::Validity;
 use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};

 impl TakeKernel for VarBinVTable {
@@ -40,7 +43,7 @@ impl TakeKernel for VarBinVTable {

 register_kernel!(TakeKernelAdapter(VarBinVTable).lift());

-fn take<I: NativePType, O: NativePType + PrimInt + Sum>(
+fn take<I: NativePType, O: NativePType + PrimInt + Sum + AddAssign>(
     dtype: DType,
     offsets: &[O],
     data: &[u8],
@@ -59,55 +62,122 @@ fn take<I: NativePType, O: NativePType + PrimInt + Sum>(
         ));
     }

-    let mut builder = VarBinBuilder::<u32>::with_capacity(indices.len());
+    let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
+    new_offsets.push(O::zero());
+    let mut current_offset = O::zero();

     for &idx in indices {
         let idx = idx
             .to_usize()
-            .ok_or_else(|| vortex_err!("Failed to convert index to usize: {}", idx))?;
+            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
+        let start = offsets[idx];
+        let stop = offsets[idx + 1];
+        current_offset += stop - start;
+        new_offsets.push(current_offset);
+    }
+
+    let mut new_data = ByteBufferMut::with_capacity(
+        current_offset
+            .to_usize()
+            .vortex_expect("Failed to cast max offset to usize"),
+    );
+
+    for idx in indices {
+        let idx = idx
+            .to_usize()
+            .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", idx));
         let start = offsets[idx]
             .to_usize()
-            .ok_or_else(|| vortex_err!("Failed to convert offset to usize: {}", offsets[idx]))?;
-        let stop = offsets[idx + 1].to_usize().ok_or_else(|| {
-            vortex_err!("Failed to convert offset to usize: {}", offsets[idx + 1])
-        })?;
-        builder.append_value(&data[start..stop]);
+            .vortex_expect("Failed to cast offset to usize");
+        let stop = offsets[idx + 1]
+            .to_usize()
+            .vortex_expect("Failed to cast offset to usize");
+        new_data.extend_from_slice(&data[start..stop]);
     }

-    Ok(builder.finish(dtype))
+    let array_validity = Validity::from(dtype.nullability());
+
+    // Safety:
+    // All invariants of VarBinArray are satisfied here.
+    unsafe {
+        Ok(VarBinArray::new_unchecked(
+            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
+            new_data.freeze(),
+            dtype,
+            array_validity,
+        ))
+    }
 }

-fn take_nullable<I: NativePType, O: NativePType + PrimInt>(
+fn take_nullable<I: NativePType, O: NativePType + PrimInt + Sum + AddAssign>(
     dtype: DType,
     offsets: &[O],
     data: &[u8],
     indices: &[I],
     data_validity: Mask,
     indices_validity: Mask,
 ) -> VarBinArray {
-    let mut builder = VarBinBuilder::<u32>::with_capacity(indices.len());
+    let mut new_offsets = BufferMut::with_capacity(indices.len() + 1);
+    new_offsets.push(O::zero());
+    let mut current_offset = O::zero();
+
+    let mut validity_buffer = BooleanBufferBuilder::new(indices.len());
+
+    // Convert indices once and store valid ones with their positions
+    let mut valid_indices = Vec::with_capacity(indices.len());
+
+    // First pass: calculate offsets and validity
     for (idx, data_idx) in indices.iter().enumerate() {
         if !indices_validity.value(idx) {
-            builder.append_null();
+            validity_buffer.append(false);
+            new_offsets.push(current_offset);
             continue;
         }
-        let data_idx = data_idx
+        let data_idx_usize = data_idx
             .to_usize()
             .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));
-        if data_validity.value(data_idx) {
-            let start = offsets[data_idx].to_usize().unwrap_or_else(|| {
-                vortex_panic!("Failed to convert offset to usize: {}", offsets[data_idx])
-            });
-            let stop = offsets[data_idx + 1].to_usize().unwrap_or_else(|| {
-                vortex_panic!(
-                    "Failed to convert offset to usize: {}",
-                    offsets[data_idx + 1]
-                )
-            });
-            builder.append_value(&data[start..stop]);
+        if data_validity.value(data_idx_usize) {
+            validity_buffer.append(true);
+            let start = offsets[data_idx_usize];
+            let stop = offsets[data_idx_usize + 1];
+            current_offset += stop - start;
+            new_offsets.push(current_offset);
+            valid_indices.push(data_idx_usize);
         } else {
-            builder.append_null();
+            validity_buffer.append(false);
+            new_offsets.push(current_offset);
         }
     }
-    builder.finish(dtype)
+
+    let mut new_data = ByteBufferMut::with_capacity(
+        current_offset
+            .to_usize()
+            .vortex_expect("Failed to cast max offset to usize"),
+    );
+
+    // Second pass: copy data for valid indices only
+    for data_idx in valid_indices {
+        let start = offsets[data_idx]
+            .to_usize()
+            .vortex_expect("Failed to cast offset to usize");
+        let stop = offsets[data_idx + 1]
+            .to_usize()
+            .vortex_expect("Failed to cast offset to usize");
+        new_data.extend_from_slice(&data[start..stop]);
+    }
+
+    let array_validity = Validity::from(validity_buffer.finish());
+
+    // Safety:
+    // All invariants of VarBinArray are satisfied here.
+    unsafe {
+        VarBinArray::new_unchecked(
+            PrimitiveArray::new(new_offsets.freeze(), Validity::NonNullable).into_array(),
+            new_data.freeze(),
+            dtype,
+            array_validity,
+        )
+    }
 }

 #[cfg(test)]
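The change replaces the per-element VarBinBuilder with a two-pass gather: the first pass converts each index once and accumulates the output offsets (which is why O now also needs AddAssign), so the byte buffer can be allocated at its exact final size before the second pass copies the data. Below is a minimal sketch of that pattern on plain std types; take_varbin is a hypothetical name standing in for the Vortex BufferMut/ByteBufferMut and Mask machinery, not part of the Vortex API.

// A minimal sketch of the two-pass take used in this diff.

/// Gather variable-length values at `indices` from an (offsets, data) pair,
/// where slot i spans data[offsets[i]..offsets[i + 1]].
fn take_varbin(offsets: &[u32], data: &[u8], indices: &[usize]) -> (Vec<u32>, Vec<u8>) {
    // Pass 1: accumulate output offsets and the exact byte count.
    let mut new_offsets = Vec::with_capacity(indices.len() + 1);
    new_offsets.push(0u32);
    let mut total = 0u32;
    for &idx in indices {
        total += offsets[idx + 1] - offsets[idx];
        new_offsets.push(total);
    }

    // Pass 2: copy into a buffer sized exactly once, so it never regrows.
    let mut new_data = Vec::with_capacity(total as usize);
    for &idx in indices {
        new_data.extend_from_slice(&data[offsets[idx] as usize..offsets[idx + 1] as usize]);
    }

    (new_offsets, new_data)
}

fn main() {
    // Three values: "ab", "c", "def".
    let offsets = [0u32, 2, 3, 6];
    let data = b"abcdef";
    let (o, d) = take_varbin(&offsets, data, &[2, 0]);
    assert_eq!(o, vec![0, 3, 5]); // "def" (3 bytes), then "ab" (2 bytes)
    assert_eq!(d, b"defab".to_vec());
}

For null slots, the nullable path in the diff pushes the running offset again (an empty span) and records false in the validity bitmap, so the offsets stay monotonic and the second pass only touches the indices recorded as valid.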