Skip to content

Commit 70809ca

Browse files
committed
Add custom implemenation for zip for string-views scalars
1 parent 282cd50 commit 70809ca

File tree

1 file changed

+327
-4
lines changed

1 file changed

+327
-4
lines changed

arrow-select/src/zip.rs

Lines changed: 327 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@
1919
2020
use crate::filter::{SlicesIterator, prep_null_mask_filter};
2121
use arrow_array::cast::AsArray;
22-
use arrow_array::types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type};
22+
use arrow_array::types::{
23+
BinaryType, ByteArrayType, ByteViewType, LargeBinaryType, LargeUtf8Type, StringViewType,
24+
Utf8Type,
25+
};
2326
use arrow_array::*;
2427
use arrow_buffer::{
2528
BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, OffsetBufferBuilder,
26-
ScalarBuffer,
29+
ScalarBuffer, ToByteSlice,
2730
};
28-
use arrow_data::ArrayData;
2931
use arrow_data::transform::MutableArrayData;
32+
use arrow_data::{ArrayData, ByteView};
3033
use arrow_schema::{ArrowError, DataType};
3134
use std::fmt::{Debug, Formatter};
3235
use std::hash::Hash;
@@ -284,7 +287,9 @@ impl ScalarZipper {
284287
DataType::LargeBinary => {
285288
Arc::new(BytesScalarImpl::<LargeBinaryType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
286289
},
287-
// TODO: Handle Utf8View https://github.com/apache/arrow-rs/issues/8724
290+
DataType::Utf8View => {
291+
Arc::new(ByteViewScalarImpl::<StringViewType>::new(truthy, falsy)) as Arc<dyn ZipImpl>
292+
},
288293
_ => {
289294
Arc::new(FallbackImpl::new(truthy, falsy)) as Arc<dyn ZipImpl>
290295
},
@@ -657,6 +662,182 @@ fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer {
657662
}
658663
}
659664

665+
struct ByteViewScalarImpl<T: ByteViewType> {
666+
truthy: Option<StringViewArray>,
667+
falsy: Option<StringViewArray>,
668+
phantom: PhantomData<T>,
669+
}
670+
671+
impl<T: ByteViewType> ByteViewScalarImpl<T> {
672+
fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self {
673+
Self {
674+
truthy: Self::get_value_from_scalar(truthy),
675+
falsy: Self::get_value_from_scalar(falsy),
676+
phantom: PhantomData,
677+
}
678+
}
679+
680+
fn get_value_from_scalar(scalar: &dyn Array) -> Option<StringViewArray> {
681+
if scalar.is_null(0) {
682+
None
683+
} else {
684+
Some(scalar.as_string_view().clone())
685+
}
686+
}
687+
688+
fn get_scalar_buffers_and_nulls_for_all_values_null(
689+
len: usize,
690+
) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
691+
let mut mutable = MutableBuffer::with_capacity(0);
692+
mutable.repeat_slice_n_times((0u128).to_byte_slice(), len);
693+
694+
(mutable.into(), vec![], Some(NullBuffer::new_null(len)))
695+
}
696+
697+
fn get_scalar_buffers_and_nulls_for_single_non_nullable(
698+
predicate: BooleanBuffer,
699+
value: &StringViewArray,
700+
) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
701+
let number_of_true = predicate.count_set_bits();
702+
let number_of_values = predicate.len();
703+
704+
// Fast path for all nulls
705+
if number_of_true == 0 {
706+
// All values are null
707+
return Self::get_scalar_buffers_and_nulls_for_all_values_null(number_of_values);
708+
}
709+
let view = value.views()[0].to_byte_slice();
710+
let mut bytes = MutableBuffer::with_capacity(0);
711+
bytes.repeat_slice_n_times(view, number_of_values);
712+
713+
let bytes = Buffer::from(bytes);
714+
715+
// If a value is true we need the TRUTHY and the null buffer will have 1 (meaning not null)
716+
// If a value is false we need the FALSY and the null buffer will have 0 (meaning null)
717+
let nulls = NullBuffer::new(predicate);
718+
(bytes.into(), value.data_buffers().into(), Some(nulls))
719+
}
720+
721+
fn get_scalar_buffers_and_nulls_non_nullable(
722+
predicate: BooleanBuffer,
723+
truthy: &StringViewArray,
724+
falsy: &StringViewArray,
725+
) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
726+
let true_count = predicate.count_set_bits();
727+
let view_truthy = truthy.views()[0].to_byte_slice();
728+
let mut buffers: Vec<Buffer> = truthy.data_buffers().to_vec();
729+
730+
// if falsy has non-inlined values in the buffer,
731+
// include the buffers and recalculate the view,
732+
// otherwise, we simply use the view.
733+
let view_falsy = if falsy.total_buffer_bytes_used() > 0 {
734+
let byte_view_falsy = ByteView::from(falsy.views()[0]);
735+
let new_index_falsy_buffers = buffers.len() as u32;
736+
buffers.extend(falsy.data_buffers().to_vec());
737+
let byte_view_falsy = byte_view_falsy.with_buffer_index(new_index_falsy_buffers);
738+
byte_view_falsy.as_u128()
739+
} else {
740+
falsy.views()[0]
741+
};
742+
743+
let total_number_of_bytes = true_count * view_truthy.len()
744+
+ (predicate.len() - true_count) * view_falsy.to_byte_slice().len();
745+
let mut mutable = MutableBuffer::new(total_number_of_bytes);
746+
let mut filled = 0;
747+
748+
SlicesIterator::from(&predicate).for_each(|(start, end)| {
749+
if start > filled {
750+
let false_repeat_count = start - filled;
751+
mutable.repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count);
752+
}
753+
let true_repeat_count = end - start;
754+
mutable.repeat_slice_n_times(view_truthy, true_repeat_count);
755+
filled = end;
756+
});
757+
758+
if filled < predicate.len() {
759+
let false_repeat_count = predicate.len() - filled;
760+
mutable.repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count);
761+
}
762+
763+
let bytes = Buffer::from(mutable);
764+
765+
(
766+
bytes.into(),
767+
buffers,
768+
Some(NullBuffer::new_valid(predicate.len())),
769+
)
770+
}
771+
772+
fn get_scalar_buffers_and_nulls_for_all_same_value(
773+
length: usize,
774+
value: &StringViewArray,
775+
) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
776+
let (views, buffers, _) = value.clone().into_parts();
777+
let mut mutable = MutableBuffer::with_capacity(0);
778+
mutable.repeat_slice_n_times(views[0].to_byte_slice(), length);
779+
780+
let bytes = Buffer::from(mutable);
781+
782+
(bytes.into(), buffers, Some(NullBuffer::new_valid(length)))
783+
}
784+
785+
fn create_output_on_non_nulls(
786+
predicate: BooleanBuffer,
787+
result_len: usize,
788+
truthy: &StringViewArray,
789+
falsy: &StringViewArray,
790+
) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) {
791+
let true_count = predicate.count_set_bits();
792+
match true_count {
793+
0 => {
794+
// all values are falsy
795+
Self::get_scalar_buffers_and_nulls_for_all_same_value(result_len, falsy)
796+
}
797+
n if n == predicate.len() => {
798+
// all values are truthy
799+
Self::get_scalar_buffers_and_nulls_for_all_same_value(result_len, truthy)
800+
}
801+
_ => Self::get_scalar_buffers_and_nulls_non_nullable(predicate, truthy, falsy),
802+
}
803+
}
804+
}
805+
806+
impl<T: ByteViewType> Debug for ByteViewScalarImpl<T> {
807+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
808+
f.debug_struct("ByteViewScalarImpl")
809+
.field("truthy", &self.truthy)
810+
.field("falsy", &self.falsy)
811+
.finish()
812+
}
813+
}
814+
815+
impl<T: ByteViewType> ZipImpl for ByteViewScalarImpl<T> {
816+
fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
817+
let result_len = predicate.len();
818+
// Nulls are treated as false
819+
let predicate = maybe_prep_null_mask_filter(predicate);
820+
821+
let (views, buffers, nulls) = match (self.truthy.as_ref(), self.falsy.as_ref()) {
822+
(Some(truthy), Some(falsy)) => {
823+
Self::create_output_on_non_nulls(predicate, result_len, truthy, falsy)
824+
}
825+
(Some(truthy), None) => {
826+
Self::get_scalar_buffers_and_nulls_for_single_non_nullable(predicate, truthy)
827+
}
828+
829+
(None, Some(falsy)) => {
830+
let predicate = predicate.not();
831+
Self::get_scalar_buffers_and_nulls_for_single_non_nullable(predicate, falsy)
832+
}
833+
(None, None) => Self::get_scalar_buffers_and_nulls_for_all_values_null(result_len),
834+
};
835+
836+
let result = unsafe { StringViewArray::new_unchecked(views, buffers, nulls) };
837+
Ok(Arc::new(result))
838+
}
839+
}
840+
660841
#[cfg(test)]
661842
mod test {
662843
use super::*;
@@ -1222,4 +1403,146 @@ mod test {
12221403
]);
12231404
assert_eq!(actual, &expected);
12241405
}
1406+
1407+
#[test]
1408+
fn test_zip_kernel_scalar_strings_array_view() {
1409+
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
1410+
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"]));
1411+
1412+
let mask = BooleanArray::from(vec![true, false, true, false]);
1413+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1414+
let actual = out.as_string_view();
1415+
let expected = StringViewArray::from(vec![
1416+
Some("hello"),
1417+
Some("world"),
1418+
Some("hello"),
1419+
Some("world"),
1420+
]);
1421+
assert_eq!(actual, &expected);
1422+
}
1423+
1424+
#[test]
1425+
fn test_zip_kernel_scalar_strings_array_view_with_nulls() {
1426+
let scalar_truthy = Scalar::new(StringViewArray::from_iter_values(["hello"]));
1427+
let scalar_falsy = Scalar::new(StringViewArray::new_null(1));
1428+
1429+
let mask = BooleanArray::from(vec![true, true, false, false, true]);
1430+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1431+
let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap();
1432+
let expected = StringViewArray::from_iter(vec![
1433+
Some("hello"),
1434+
Some("hello"),
1435+
None,
1436+
None,
1437+
Some("hello"),
1438+
]);
1439+
assert_eq!(actual, &expected);
1440+
}
1441+
1442+
#[test]
1443+
fn test_zip_kernel_scalar_strings_array_view_all_true_null() {
1444+
let scalar_truthy = Scalar::new(StringViewArray::new_null(1));
1445+
let scalar_falsy = Scalar::new(StringViewArray::new_null(1));
1446+
let mask = BooleanArray::from(vec![true, true]);
1447+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1448+
let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap();
1449+
let expected = StringViewArray::from_iter(vec![None::<String>, None]);
1450+
assert_eq!(actual, &expected);
1451+
}
1452+
1453+
#[test]
1454+
fn test_zip_kernel_scalar_strings_array_view_all_false_null() {
1455+
let scalar_truthy = Scalar::new(StringViewArray::new_null(1));
1456+
let scalar_falsy = Scalar::new(StringViewArray::new_null(1));
1457+
let mask = BooleanArray::from(vec![false, false]);
1458+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1459+
let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap();
1460+
let expected = StringViewArray::from_iter(vec![None::<String>, None]);
1461+
assert_eq!(actual, &expected);
1462+
}
1463+
1464+
#[test]
1465+
fn test_zip_kernel_scalar_string_array_view_all_true() {
1466+
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
1467+
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"]));
1468+
1469+
let mask = BooleanArray::from(vec![true, true]);
1470+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1471+
let actual = out.as_string_view();
1472+
let expected = StringViewArray::from(vec![Some("hello"), Some("hello")]);
1473+
assert_eq!(actual, &expected);
1474+
}
1475+
1476+
#[test]
1477+
fn test_zip_kernel_scalar_string_array_view_all_false() {
1478+
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
1479+
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"]));
1480+
1481+
let mask = BooleanArray::from(vec![false, false]);
1482+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1483+
let actual = out.as_string_view();
1484+
let expected = StringViewArray::from(vec![Some("world"), Some("world")]);
1485+
assert_eq!(actual, &expected);
1486+
}
1487+
1488+
#[test]
1489+
fn test_zip_kernel_scalar_strings_large_strings() {
1490+
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
1491+
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));
1492+
1493+
let mask = BooleanArray::from(vec![true, false]);
1494+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1495+
let actual = out.as_string_view();
1496+
let expected = StringViewArray::from(vec![
1497+
Some("longer than 12 bytes"),
1498+
Some("another longer than 12 bytes"),
1499+
]);
1500+
assert_eq!(actual, &expected);
1501+
}
1502+
1503+
#[test]
1504+
fn test_zip_kernel_scalar_strings_array_view_large_short_strings() {
1505+
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"]));
1506+
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
1507+
1508+
let mask = BooleanArray::from(vec![true, false, true, false]);
1509+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1510+
let actual = out.as_string_view();
1511+
let expected = StringViewArray::from(vec![
1512+
Some("hello"),
1513+
Some("longer than 12 bytes"),
1514+
Some("hello"),
1515+
Some("longer than 12 bytes"),
1516+
]);
1517+
assert_eq!(actual, &expected);
1518+
}
1519+
#[test]
1520+
fn test_zip_kernel_scalar_strings_array_view_large_all_true() {
1521+
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
1522+
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));
1523+
1524+
let mask = BooleanArray::from(vec![true, true]);
1525+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1526+
let actual = out.as_string_view();
1527+
let expected = StringViewArray::from(vec![
1528+
Some("longer than 12 bytes"),
1529+
Some("longer than 12 bytes"),
1530+
]);
1531+
assert_eq!(actual, &expected);
1532+
}
1533+
1534+
#[test]
1535+
fn test_zip_kernel_scalar_strings_array_view_large_all_false() {
1536+
let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"]));
1537+
let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"]));
1538+
1539+
let mask = BooleanArray::from(vec![false, false]);
1540+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
1541+
let actual = out.as_string_view();
1542+
let expected = StringViewArray::from(vec![
1543+
Some("another longer than 12 bytes"),
1544+
Some("another longer than 12 bytes"),
1545+
]);
1546+
assert_eq!(actual, &expected);
1547+
}
12251548
}

0 commit comments

Comments
 (0)