|
19 | 19 |
|
20 | 20 | use crate::filter::{SlicesIterator, prep_null_mask_filter}; |
21 | 21 | use arrow_array::cast::AsArray; |
22 | | -use arrow_array::types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, Utf8Type}; |
| 22 | +use arrow_array::types::{ |
| 23 | + BinaryType, ByteArrayType, ByteViewType, LargeBinaryType, LargeUtf8Type, StringViewType, |
| 24 | + Utf8Type, |
| 25 | +}; |
23 | 26 | use arrow_array::*; |
24 | 27 | use arrow_buffer::{ |
25 | 28 | BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, OffsetBufferBuilder, |
26 | | - ScalarBuffer, |
| 29 | + ScalarBuffer, ToByteSlice, |
27 | 30 | }; |
28 | | -use arrow_data::ArrayData; |
29 | 31 | use arrow_data::transform::MutableArrayData; |
| 32 | +use arrow_data::{ArrayData, ByteView}; |
30 | 33 | use arrow_schema::{ArrowError, DataType}; |
31 | 34 | use std::fmt::{Debug, Formatter}; |
32 | 35 | use std::hash::Hash; |
@@ -284,7 +287,9 @@ impl ScalarZipper { |
284 | 287 | DataType::LargeBinary => { |
285 | 288 | Arc::new(BytesScalarImpl::<LargeBinaryType>::new(truthy, falsy)) as Arc<dyn ZipImpl> |
286 | 289 | }, |
287 | | - // TODO: Handle Utf8View https://github.com/apache/arrow-rs/issues/8724 |
| 290 | + DataType::Utf8View => { |
| 291 | + Arc::new(ByteViewScalarImpl::<StringViewType>::new(truthy, falsy)) as Arc<dyn ZipImpl> |
| 292 | + }, |
288 | 293 | _ => { |
289 | 294 | Arc::new(FallbackImpl::new(truthy, falsy)) as Arc<dyn ZipImpl> |
290 | 295 | }, |
@@ -657,6 +662,182 @@ fn maybe_prep_null_mask_filter(predicate: &BooleanArray) -> BooleanBuffer { |
657 | 662 | } |
658 | 663 | } |
659 | 664 |
|
| 665 | +struct ByteViewScalarImpl<T: ByteViewType> { |
| 666 | + truthy: Option<StringViewArray>, |
| 667 | + falsy: Option<StringViewArray>, |
| 668 | + phantom: PhantomData<T>, |
| 669 | +} |
| 670 | + |
| 671 | +impl<T: ByteViewType> ByteViewScalarImpl<T> { |
| 672 | + fn new(truthy: &dyn Array, falsy: &dyn Array) -> Self { |
| 673 | + Self { |
| 674 | + truthy: Self::get_value_from_scalar(truthy), |
| 675 | + falsy: Self::get_value_from_scalar(falsy), |
| 676 | + phantom: PhantomData, |
| 677 | + } |
| 678 | + } |
| 679 | + |
| 680 | + fn get_value_from_scalar(scalar: &dyn Array) -> Option<StringViewArray> { |
| 681 | + if scalar.is_null(0) { |
| 682 | + None |
| 683 | + } else { |
| 684 | + Some(scalar.as_string_view().clone()) |
| 685 | + } |
| 686 | + } |
| 687 | + |
| 688 | + fn get_scalar_buffers_and_nulls_for_all_values_null( |
| 689 | + len: usize, |
| 690 | + ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) { |
| 691 | + let mut mutable = MutableBuffer::with_capacity(0); |
| 692 | + mutable.repeat_slice_n_times((0u128).to_byte_slice(), len); |
| 693 | + |
| 694 | + (mutable.into(), vec![], Some(NullBuffer::new_null(len))) |
| 695 | + } |
| 696 | + |
| 697 | + fn get_scalar_buffers_and_nulls_for_single_non_nullable( |
| 698 | + predicate: BooleanBuffer, |
| 699 | + value: &StringViewArray, |
| 700 | + ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) { |
| 701 | + let number_of_true = predicate.count_set_bits(); |
| 702 | + let number_of_values = predicate.len(); |
| 703 | + |
| 704 | + // Fast path for all nulls |
| 705 | + if number_of_true == 0 { |
| 706 | + // All values are null |
| 707 | + return Self::get_scalar_buffers_and_nulls_for_all_values_null(number_of_values); |
| 708 | + } |
| 709 | + let view = value.views()[0].to_byte_slice(); |
| 710 | + let mut bytes = MutableBuffer::with_capacity(0); |
| 711 | + bytes.repeat_slice_n_times(view, number_of_values); |
| 712 | + |
| 713 | + let bytes = Buffer::from(bytes); |
| 714 | + |
| 715 | + // If a value is true we need the TRUTHY and the null buffer will have 1 (meaning not null) |
| 716 | + // If a value is false we need the FALSY and the null buffer will have 0 (meaning null) |
| 717 | + let nulls = NullBuffer::new(predicate); |
| 718 | + (bytes.into(), value.data_buffers().into(), Some(nulls)) |
| 719 | + } |
| 720 | + |
| 721 | + fn get_scalar_buffers_and_nulls_non_nullable( |
| 722 | + predicate: BooleanBuffer, |
| 723 | + truthy: &StringViewArray, |
| 724 | + falsy: &StringViewArray, |
| 725 | + ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) { |
| 726 | + let true_count = predicate.count_set_bits(); |
| 727 | + let view_truthy = truthy.views()[0].to_byte_slice(); |
| 728 | + let mut buffers: Vec<Buffer> = truthy.data_buffers().to_vec(); |
| 729 | + |
| 730 | + // if falsy has non-inlined values in the buffer, |
| 731 | + // include the buffers and recalculate the view, |
| 732 | + // otherwise, we simply use the view. |
| 733 | + let view_falsy = if falsy.total_buffer_bytes_used() > 0 { |
| 734 | + let byte_view_falsy = ByteView::from(falsy.views()[0]); |
| 735 | + let new_index_falsy_buffers = buffers.len() as u32; |
| 736 | + buffers.extend(falsy.data_buffers().to_vec()); |
| 737 | + let byte_view_falsy = byte_view_falsy.with_buffer_index(new_index_falsy_buffers); |
| 738 | + byte_view_falsy.as_u128() |
| 739 | + } else { |
| 740 | + falsy.views()[0] |
| 741 | + }; |
| 742 | + |
| 743 | + let total_number_of_bytes = true_count * view_truthy.len() |
| 744 | + + (predicate.len() - true_count) * view_falsy.to_byte_slice().len(); |
| 745 | + let mut mutable = MutableBuffer::new(total_number_of_bytes); |
| 746 | + let mut filled = 0; |
| 747 | + |
| 748 | + SlicesIterator::from(&predicate).for_each(|(start, end)| { |
| 749 | + if start > filled { |
| 750 | + let false_repeat_count = start - filled; |
| 751 | + mutable.repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count); |
| 752 | + } |
| 753 | + let true_repeat_count = end - start; |
| 754 | + mutable.repeat_slice_n_times(view_truthy, true_repeat_count); |
| 755 | + filled = end; |
| 756 | + }); |
| 757 | + |
| 758 | + if filled < predicate.len() { |
| 759 | + let false_repeat_count = predicate.len() - filled; |
| 760 | + mutable.repeat_slice_n_times(view_falsy.to_byte_slice(), false_repeat_count); |
| 761 | + } |
| 762 | + |
| 763 | + let bytes = Buffer::from(mutable); |
| 764 | + |
| 765 | + ( |
| 766 | + bytes.into(), |
| 767 | + buffers, |
| 768 | + Some(NullBuffer::new_valid(predicate.len())), |
| 769 | + ) |
| 770 | + } |
| 771 | + |
| 772 | + fn get_scalar_buffers_and_nulls_for_all_same_value( |
| 773 | + length: usize, |
| 774 | + value: &StringViewArray, |
| 775 | + ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) { |
| 776 | + let (views, buffers, _) = value.clone().into_parts(); |
| 777 | + let mut mutable = MutableBuffer::with_capacity(0); |
| 778 | + mutable.repeat_slice_n_times(views[0].to_byte_slice(), length); |
| 779 | + |
| 780 | + let bytes = Buffer::from(mutable); |
| 781 | + |
| 782 | + (bytes.into(), buffers, Some(NullBuffer::new_valid(length))) |
| 783 | + } |
| 784 | + |
| 785 | + fn create_output_on_non_nulls( |
| 786 | + predicate: BooleanBuffer, |
| 787 | + result_len: usize, |
| 788 | + truthy: &StringViewArray, |
| 789 | + falsy: &StringViewArray, |
| 790 | + ) -> (ScalarBuffer<u128>, Vec<Buffer>, Option<NullBuffer>) { |
| 791 | + let true_count = predicate.count_set_bits(); |
| 792 | + match true_count { |
| 793 | + 0 => { |
| 794 | + // all values are falsy |
| 795 | + Self::get_scalar_buffers_and_nulls_for_all_same_value(result_len, falsy) |
| 796 | + } |
| 797 | + n if n == predicate.len() => { |
| 798 | + // all values are truthy |
| 799 | + Self::get_scalar_buffers_and_nulls_for_all_same_value(result_len, truthy) |
| 800 | + } |
| 801 | + _ => Self::get_scalar_buffers_and_nulls_non_nullable(predicate, truthy, falsy), |
| 802 | + } |
| 803 | + } |
| 804 | +} |
| 805 | + |
| 806 | +impl<T: ByteViewType> Debug for ByteViewScalarImpl<T> { |
| 807 | + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
| 808 | + f.debug_struct("ByteViewScalarImpl") |
| 809 | + .field("truthy", &self.truthy) |
| 810 | + .field("falsy", &self.falsy) |
| 811 | + .finish() |
| 812 | + } |
| 813 | +} |
| 814 | + |
| 815 | +impl<T: ByteViewType> ZipImpl for ByteViewScalarImpl<T> { |
| 816 | + fn create_output(&self, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> { |
| 817 | + let result_len = predicate.len(); |
| 818 | + // Nulls are treated as false |
| 819 | + let predicate = maybe_prep_null_mask_filter(predicate); |
| 820 | + |
| 821 | + let (views, buffers, nulls) = match (self.truthy.as_ref(), self.falsy.as_ref()) { |
| 822 | + (Some(truthy), Some(falsy)) => { |
| 823 | + Self::create_output_on_non_nulls(predicate, result_len, truthy, falsy) |
| 824 | + } |
| 825 | + (Some(truthy), None) => { |
| 826 | + Self::get_scalar_buffers_and_nulls_for_single_non_nullable(predicate, truthy) |
| 827 | + } |
| 828 | + |
| 829 | + (None, Some(falsy)) => { |
| 830 | + let predicate = predicate.not(); |
| 831 | + Self::get_scalar_buffers_and_nulls_for_single_non_nullable(predicate, falsy) |
| 832 | + } |
| 833 | + (None, None) => Self::get_scalar_buffers_and_nulls_for_all_values_null(result_len), |
| 834 | + }; |
| 835 | + |
| 836 | + let result = unsafe { StringViewArray::new_unchecked(views, buffers, nulls) }; |
| 837 | + Ok(Arc::new(result)) |
| 838 | + } |
| 839 | +} |
| 840 | + |
660 | 841 | #[cfg(test)] |
661 | 842 | mod test { |
662 | 843 | use super::*; |
@@ -1222,4 +1403,146 @@ mod test { |
1222 | 1403 | ]); |
1223 | 1404 | assert_eq!(actual, &expected); |
1224 | 1405 | } |
| 1406 | + |
| 1407 | + #[test] |
| 1408 | + fn test_zip_kernel_scalar_strings_array_view() { |
| 1409 | + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"])); |
| 1410 | + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"])); |
| 1411 | + |
| 1412 | + let mask = BooleanArray::from(vec![true, false, true, false]); |
| 1413 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1414 | + let actual = out.as_string_view(); |
| 1415 | + let expected = StringViewArray::from(vec![ |
| 1416 | + Some("hello"), |
| 1417 | + Some("world"), |
| 1418 | + Some("hello"), |
| 1419 | + Some("world"), |
| 1420 | + ]); |
| 1421 | + assert_eq!(actual, &expected); |
| 1422 | + } |
| 1423 | + |
| 1424 | + #[test] |
| 1425 | + fn test_zip_kernel_scalar_strings_array_view_with_nulls() { |
| 1426 | + let scalar_truthy = Scalar::new(StringViewArray::from_iter_values(["hello"])); |
| 1427 | + let scalar_falsy = Scalar::new(StringViewArray::new_null(1)); |
| 1428 | + |
| 1429 | + let mask = BooleanArray::from(vec![true, true, false, false, true]); |
| 1430 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1431 | + let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap(); |
| 1432 | + let expected = StringViewArray::from_iter(vec![ |
| 1433 | + Some("hello"), |
| 1434 | + Some("hello"), |
| 1435 | + None, |
| 1436 | + None, |
| 1437 | + Some("hello"), |
| 1438 | + ]); |
| 1439 | + assert_eq!(actual, &expected); |
| 1440 | + } |
| 1441 | + |
| 1442 | + #[test] |
| 1443 | + fn test_zip_kernel_scalar_strings_array_view_all_true_null() { |
| 1444 | + let scalar_truthy = Scalar::new(StringViewArray::new_null(1)); |
| 1445 | + let scalar_falsy = Scalar::new(StringViewArray::new_null(1)); |
| 1446 | + let mask = BooleanArray::from(vec![true, true]); |
| 1447 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1448 | + let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap(); |
| 1449 | + let expected = StringViewArray::from_iter(vec![None::<String>, None]); |
| 1450 | + assert_eq!(actual, &expected); |
| 1451 | + } |
| 1452 | + |
| 1453 | + #[test] |
| 1454 | + fn test_zip_kernel_scalar_strings_array_view_all_false_null() { |
| 1455 | + let scalar_truthy = Scalar::new(StringViewArray::new_null(1)); |
| 1456 | + let scalar_falsy = Scalar::new(StringViewArray::new_null(1)); |
| 1457 | + let mask = BooleanArray::from(vec![false, false]); |
| 1458 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1459 | + let actual = out.as_any().downcast_ref::<StringViewArray>().unwrap(); |
| 1460 | + let expected = StringViewArray::from_iter(vec![None::<String>, None]); |
| 1461 | + assert_eq!(actual, &expected); |
| 1462 | + } |
| 1463 | + |
| 1464 | + #[test] |
| 1465 | + fn test_zip_kernel_scalar_string_array_view_all_true() { |
| 1466 | + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"])); |
| 1467 | + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"])); |
| 1468 | + |
| 1469 | + let mask = BooleanArray::from(vec![true, true]); |
| 1470 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1471 | + let actual = out.as_string_view(); |
| 1472 | + let expected = StringViewArray::from(vec![Some("hello"), Some("hello")]); |
| 1473 | + assert_eq!(actual, &expected); |
| 1474 | + } |
| 1475 | + |
| 1476 | + #[test] |
| 1477 | + fn test_zip_kernel_scalar_string_array_view_all_false() { |
| 1478 | + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"])); |
| 1479 | + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["world"])); |
| 1480 | + |
| 1481 | + let mask = BooleanArray::from(vec![false, false]); |
| 1482 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1483 | + let actual = out.as_string_view(); |
| 1484 | + let expected = StringViewArray::from(vec![Some("world"), Some("world")]); |
| 1485 | + assert_eq!(actual, &expected); |
| 1486 | + } |
| 1487 | + |
| 1488 | + #[test] |
| 1489 | + fn test_zip_kernel_scalar_strings_large_strings() { |
| 1490 | + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"])); |
| 1491 | + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"])); |
| 1492 | + |
| 1493 | + let mask = BooleanArray::from(vec![true, false]); |
| 1494 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1495 | + let actual = out.as_string_view(); |
| 1496 | + let expected = StringViewArray::from(vec![ |
| 1497 | + Some("longer than 12 bytes"), |
| 1498 | + Some("another longer than 12 bytes"), |
| 1499 | + ]); |
| 1500 | + assert_eq!(actual, &expected); |
| 1501 | + } |
| 1502 | + |
| 1503 | + #[test] |
| 1504 | + fn test_zip_kernel_scalar_strings_array_view_large_short_strings() { |
| 1505 | + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["hello"])); |
| 1506 | + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"])); |
| 1507 | + |
| 1508 | + let mask = BooleanArray::from(vec![true, false, true, false]); |
| 1509 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1510 | + let actual = out.as_string_view(); |
| 1511 | + let expected = StringViewArray::from(vec![ |
| 1512 | + Some("hello"), |
| 1513 | + Some("longer than 12 bytes"), |
| 1514 | + Some("hello"), |
| 1515 | + Some("longer than 12 bytes"), |
| 1516 | + ]); |
| 1517 | + assert_eq!(actual, &expected); |
| 1518 | + } |
| 1519 | + #[test] |
| 1520 | + fn test_zip_kernel_scalar_strings_array_view_large_all_true() { |
| 1521 | + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"])); |
| 1522 | + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"])); |
| 1523 | + |
| 1524 | + let mask = BooleanArray::from(vec![true, true]); |
| 1525 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1526 | + let actual = out.as_string_view(); |
| 1527 | + let expected = StringViewArray::from(vec![ |
| 1528 | + Some("longer than 12 bytes"), |
| 1529 | + Some("longer than 12 bytes"), |
| 1530 | + ]); |
| 1531 | + assert_eq!(actual, &expected); |
| 1532 | + } |
| 1533 | + |
| 1534 | + #[test] |
| 1535 | + fn test_zip_kernel_scalar_strings_array_view_large_all_false() { |
| 1536 | + let scalar_truthy = Scalar::new(StringViewArray::from(vec!["longer than 12 bytes"])); |
| 1537 | + let scalar_falsy = Scalar::new(StringViewArray::from(vec!["another longer than 12 bytes"])); |
| 1538 | + |
| 1539 | + let mask = BooleanArray::from(vec![false, false]); |
| 1540 | + let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap(); |
| 1541 | + let actual = out.as_string_view(); |
| 1542 | + let expected = StringViewArray::from(vec![ |
| 1543 | + Some("another longer than 12 bytes"), |
| 1544 | + Some("another longer than 12 bytes"), |
| 1545 | + ]); |
| 1546 | + assert_eq!(actual, &expected); |
| 1547 | + } |
1225 | 1548 | } |
0 commit comments