Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 342 additions & 4 deletions arrow-select/src/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@

use crate::filter::{SlicesIterator, prep_null_mask_filter};
use arrow_array::cast::AsArray;
use arrow_array::types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type};
use arrow_array::types::{
BinaryType, BinaryViewType, ByteArrayType, ByteViewType, LargeBinaryType, LargeUtf8Type,
StringViewType, Utf8Type,
};
use arrow_array::*;
use arrow_buffer::{
BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, OffsetBufferBuilder,
ScalarBuffer,
ScalarBuffer, ToByteSlice,
};
use arrow_data::ArrayData;
use arrow_data::transform::MutableArrayData;
use arrow_data::{ArrayData, ByteView};
use arrow_schema::{ArrowError, DataType};
use std::fmt::{Debug, Formatter};
use std::hash::Hash;
Expand Down Expand Up @@ -284,7 +287,12 @@ impl ScalarZipper {
DataType::LargeBinary => {
Arc::new(BytesScalarImpl::<LargeBinaryType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
},
// TODO: Handle Utf8View https://github.com/apache/arrow-rs/issues/8724
DataType::Utf8View => {
Arc::new(ByteViewScalarImpl::<StringViewType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
},
DataType::BinaryView => {
Arc::new(ByteViewScalarImpl::<BinaryViewType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
},
_ => {
Arc::new(FallbackImpl::new(truthy, falsy)) as Arc<dyn ZipImpl>
},
Expand Down Expand Up @@ -657,6 +665,182 @@ fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer {
}
}

/// Zip implementation used when both inputs are scalar byte-view arrays
/// (`Utf8View` or `BinaryView`).
///
/// Each side keeps the array backing the scalar, or `None` when the scalar's
/// only element is null.
struct ByteViewScalarImpl<T: ByteViewType> {
    /// Value emitted where the predicate is `true`; `None` means null.
    truthy: Option<GenericByteViewArray<T>>,
    /// Value emitted where the predicate is `false`; `None` means null.
    falsy: Option<GenericByteViewArray<T>>,
    /// Marker tying the implementation to the byte-view type `T`.
    phantom: PhantomData<T>,
}

impl<T: ByteViewType> ByteViewScalarImpl<T> {
    /// Creates the implementation from the arrays backing the truthy and
    /// falsy scalars. Element 0 of each array is the scalar value.
    fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self {
        Self {
            truthy: Self::get_value_from_scalar(truthy),
            falsy: Self::get_value_from_scalar(falsy),
            phantom: PhantomData,
        }
    }

    /// Returns the byte-view array behind `scalar`, or `None` when its first
    /// element is null.
    fn get_value_from_scalar(scalar: &dyn Array) -> Option<GenericByteViewArray<T>> {
        if scalar.is_null(0) {
            None
        } else {
            Some(scalar.as_byte_view().clone())
        }
    }

    /// Builds the parts of an output of `len` elements that are all null:
    /// zeroed views, no data buffers and an all-null validity buffer.
    fn get_scalar_buffers_and_nulls_for_all_values_null(
        len: usize,
    ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
        let mut mutable = MutableBuffer::with_capacity(0);
        mutable.repeat_slice_n_times((0u128).to_byte_slice(), len);

        (mutable.into(), vec![], Some(NullBuffer::new_null(len)))
    }

    /// Builds the output parts when only one side is non-null.
    ///
    /// `predicate` must be set exactly where `value` should appear; every
    /// other position is null. The predicate is reused as the validity buffer,
    /// so the view of `value` can be written to every slot regardless of the
    /// predicate.
    fn get_scalar_buffers_and_nulls_for_single_non_nullable(
        predicate: BooleanBuffer,
        value: &GenericByteViewArray<T>,
    ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
        let number_of_true = predicate.count_set_bits();
        let number_of_values = predicate.len();

        // Fast path: `value` is never selected, so every output is null.
        if number_of_true == 0 {
            return Self::get_scalar_buffers_and_nulls_for_all_values_null(number_of_values);
        }

        let view = value.views()[0].to_byte_slice();
        let mut bytes = MutableBuffer::with_capacity(0);
        bytes.repeat_slice_n_times(view, number_of_values);
        let bytes = Buffer::from(bytes);

        // A set bit means "take `value`" and doubles as "valid";
        // an unset bit means "take the null side" and doubles as "null".
        let nulls = NullBuffer::new(predicate);
        (bytes.into(), value.data_buffers().into(), Some(nulls))
    }

    /// Builds the output parts when both sides are non-null and the predicate
    /// selects a mix of them.
    ///
    /// The output owns the data buffers of `truthy` followed by those of
    /// `falsy` (only when the falsy value is not inlined), and the falsy view
    /// is rewritten to index into that combined list.
    fn get_scalar_buffers_and_nulls_non_nullable(
        predicate: BooleanBuffer,
        truthy: &GenericByteViewArray<T>,
        falsy: &GenericByteViewArray<T>,
    ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
        let view_truthy = truthy.views()[0].to_byte_slice();
        let mut buffers: Vec<Buffer> = truthy.data_buffers().to_vec();

        // If falsy has a non-inlined value, append its buffers and recompute
        // the view; otherwise the view is self-contained and used as is.
        let view_falsy = if falsy.total_buffer_bytes_used() > 0 {
            let byte_view_falsy = ByteView::from(falsy.views()[0]);
            // The falsy buffers are appended after the truthy ones, so the
            // view's own buffer index has to be shifted by the number of
            // truthy buffers. Replacing it outright would be wrong whenever
            // the falsy view does not point at its array's first buffer.
            let new_index_falsy_buffers = buffers.len() as u32 + byte_view_falsy.buffer_index;
            buffers.extend_from_slice(falsy.data_buffers());
            byte_view_falsy
                .with_buffer_index(new_index_falsy_buffers)
                .as_u128()
        } else {
            falsy.views()[0]
        };
        let view_falsy_bytes = view_falsy.to_byte_slice();

        // Exactly one 16-byte view is written per predicate position.
        let mut mutable = MutableBuffer::new(predicate.len() * std::mem::size_of::<u128>());
        let mut filled = 0;

        // Each slice is a run of set bits; the gap before it is a run of
        // unset bits that gets the falsy view.
        for (start, end) in SlicesIterator::from(&predicate) {
            if start > filled {
                mutable.repeat_slice_n_times(view_falsy_bytes, start - filled);
            }
            mutable.repeat_slice_n_times(view_truthy, end - start);
            filled = end;
        }

        // Trailing run of unset bits after the last slice.
        if filled < predicate.len() {
            mutable.repeat_slice_n_times(view_falsy_bytes, predicate.len() - filled);
        }

        let bytes = Buffer::from(mutable);

        (
            bytes.into(),
            buffers,
            Some(NullBuffer::new_valid(predicate.len())),
        )
    }

    /// Builds the output parts for `length` copies of the non-null `value`.
    fn get_scalar_buffers_and_nulls_for_all_same_value(
        length: usize,
        value: &GenericByteViewArray<T>,
    ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
        let mut mutable = MutableBuffer::with_capacity(0);
        mutable.repeat_slice_n_times(value.views()[0].to_byte_slice(), length);

        let bytes = Buffer::from(mutable);

        (
            bytes.into(),
            value.data_buffers().to_vec(),
            Some(NullBuffer::new_valid(length)),
        )
    }

    /// Builds the output parts when both scalars are non-null, taking the
    /// cheaper single-value path when the predicate is all `false` or all
    /// `true`.
    fn create_output_on_non_nulls(
        predicate: BooleanBuffer,
        result_len: usize,
        truthy: &GenericByteViewArray<T>,
        falsy: &GenericByteViewArray<T>,
    ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
        let true_count = predicate.count_set_bits();
        match true_count {
            0 => {
                // all values are falsy
                Self::get_scalar_buffers_and_nulls_for_all_same_value(result_len, falsy)
            }
            n if n == predicate.len() => {
                // all values are truthy
                Self::get_scalar_buffers_and_nulls_for_all_same_value(result_len, truthy)
            }
            _ => Self::get_scalar_buffers_and_nulls_non_nullable(predicate, truthy, falsy),
        }
    }
}

impl<T: ByteViewType> Debug for ByteViewScalarImpl<T> {
    /// Formats the two scalar sides; the `phantom` marker is left out.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ByteViewScalarImpl");
        builder.field("truthy", &self.truthy);
        builder.field("falsy", &self.falsy);
        builder.finish()
    }
}

impl<T: ByteViewType> ZipImpl for ByteViewScalarImpl<T> {
    /// Produces an array with one element per predicate entry, holding the
    /// truthy scalar where the predicate is `true` and the falsy scalar
    /// elsewhere. A null scalar yields null output elements.
    fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
        let result_len = predicate.len();
        // Nulls are treated as false
        let mask = maybe_prep_null_mask_filter(predicate);

        let (views, buffers, nulls) = match (&self.truthy, &self.falsy) {
            (None, None) => Self::get_scalar_buffers_and_nulls_for_all_values_null(result_len),
            (Some(truthy), None) => {
                Self::get_scalar_buffers_and_nulls_for_single_non_nullable(mask, truthy)
            }
            (None, Some(falsy)) => {
                // The helper places its value at set bits, so flip the mask
                // to put the falsy value where the predicate is false.
                Self::get_scalar_buffers_and_nulls_for_single_non_nullable(mask.not(), falsy)
            }
            (Some(truthy), Some(falsy)) => {
                Self::create_output_on_non_nulls(mask, result_len, truthy, falsy)
            }
        };

        // SAFETY: relies on the helpers above returning views that are either
        // zeroed or copied from the scalar input arrays, together with the
        // data buffers those views index into.
        let result = unsafe { GenericByteViewArray::<T>::new_unchecked(views, buffers, nulls) };
        Ok(Arc::new(result))
    }
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -1222,4 +1406,158 @@ mod test {
]);
assert_eq!(actual, &expected);
}

/// Two short string-view scalars alternate according to the mask.
#[test]
fn test_zip_kernel_scalar_strings_array_view() {
    let truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
    let falsy = Scalar::new(StringViewArray::from(vec!["world"]));
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::from(vec![
        Some("hello"),
        Some("world"),
        Some("hello"),
        Some("world"),
    ]);
    assert_eq!(zipped.as_string_view(), &expected);
}

/// Same as the string-view case, but for binary-view scalars.
#[test]
fn test_zip_kernel_scalar_binary_array_view() {
    let truthy = Scalar::new(BinaryViewArray::from_iter_values(vec![b"hello"]));
    let falsy = Scalar::new(BinaryViewArray::from_iter_values(vec![b"world"]));
    let predicate = BooleanArray::from(vec![true, false]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = BinaryViewArray::from_iter_values(vec![b"hello", b"world"]);
    assert_eq!(zipped.as_binary_view(), &expected);
}

/// A null falsy scalar produces nulls wherever the mask is `false`.
#[test]
fn test_zip_kernel_scalar_strings_array_view_with_nulls() {
    let truthy = Scalar::new(StringViewArray::from_iter_values(["hello"]));
    let falsy = Scalar::new(StringViewArray::new_null(1));
    let predicate = BooleanArray::from(vec![true, true, false, false, true]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::from(vec![
        Some("hello"),
        Some("hello"),
        None,
        None,
        Some("hello"),
    ]);
    assert_eq!(zipped.as_string_view(), &expected);
}

/// Both scalars are null and the mask is all `true`: every output is null.
#[test]
fn test_zip_kernel_scalar_strings_array_view_all_true_null() {
    let truthy = Scalar::new(StringViewArray::new_null(1));
    let falsy = Scalar::new(StringViewArray::new_null(1));
    let predicate = BooleanArray::from(vec![true, true]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::new_null(2);
    assert_eq!(zipped.as_string_view(), &expected);
}

/// Both scalars are null and the mask is all `false`: every output is null.
#[test]
fn test_zip_kernel_scalar_strings_array_view_all_false_null() {
    let truthy = Scalar::new(StringViewArray::new_null(1));
    let falsy = Scalar::new(StringViewArray::new_null(1));
    let predicate = BooleanArray::from(vec![false, false]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::new_null(2);
    assert_eq!(zipped.as_string_view(), &expected);
}

/// An all-`true` mask repeats the truthy short string.
#[test]
fn test_zip_kernel_scalar_string_array_view_all_true() {
    let truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
    let falsy = Scalar::new(StringViewArray::from(vec!["world"]));
    let predicate = BooleanArray::from(vec![true, true]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::from(vec![Some("hello"), Some("hello")]);
    assert_eq!(zipped.as_string_view(), &expected);
}

/// An all-`false` mask repeats the falsy short string.
#[test]
fn test_zip_kernel_scalar_string_array_view_all_false() {
    let truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
    let falsy = Scalar::new(StringViewArray::from(vec!["world"]));
    let predicate = BooleanArray::from(vec![false, false]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::from(vec![Some("world"), Some("world")]);
    assert_eq!(zipped.as_string_view(), &expected);
}

/// Both scalars are longer than 12 bytes, so the output has to carry the
/// data of both.
#[test]
fn test_zip_kernel_scalar_strings_large_strings() {
    let truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
    let falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));
    let predicate = BooleanArray::from(vec![true, false]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::from(vec![
        Some("longer than 12 bytes"),
        Some("another longer than 12 bytes"),
    ]);
    assert_eq!(zipped.as_string_view(), &expected);
}

/// Mixes a short truthy string with a falsy string longer than 12 bytes.
#[test]
fn test_zip_kernel_scalar_strings_array_view_large_short_strings() {
    let truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
    let falsy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::from(vec![
        Some("hello"),
        Some("longer than 12 bytes"),
        Some("hello"),
        Some("longer than 12 bytes"),
    ]);
    assert_eq!(zipped.as_string_view(), &expected);
}
/// An all-`true` mask repeats the truthy string that is longer than 12 bytes.
#[test]
fn test_zip_kernel_scalar_strings_array_view_large_all_true() {
    let truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
    let falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));
    let predicate = BooleanArray::from(vec![true, true]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::from(vec![
        Some("longer than 12 bytes"),
        Some("longer than 12 bytes"),
    ]);
    assert_eq!(zipped.as_string_view(), &expected);
}

/// An all-`false` mask repeats the falsy string that is longer than 12 bytes.
#[test]
fn test_zip_kernel_scalar_strings_array_view_large_all_false() {
    let truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
    let falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));
    let predicate = BooleanArray::from(vec![false, false]);

    let zipped = zip(&predicate, &truthy, &falsy).unwrap();

    let expected = StringViewArray::from(vec![
        Some("another longer than 12 bytes"),
        Some("another longer than 12 bytes"),
    ]);
    assert_eq!(zipped.as_string_view(), &expected);
}
}
Loading