Skip to content

Commit daed6ab

Browse files
authored
Add dictionary array support for substring function (#1665)
* initial commit * add test * comments * more comments
1 parent e02869a commit daed6ab

File tree

1 file changed

+175
-85
lines changed

1 file changed

+175
-85
lines changed

arrow/src/compute/kernels/substring.rs

Lines changed: 175 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,137 @@
1818
//! Defines kernel to extract a substring of an Array
1919
//! Supported array types: \[Large\]StringArray, \[Large\]BinaryArray, FixedSizeBinaryArray, and DictionaryArray with one of these as its value type
2020
21+
use crate::array::DictionaryArray;
2122
use crate::buffer::MutableBuffer;
23+
use crate::datatypes::*;
2224
use crate::{array::*, buffer::Buffer};
2325
use crate::{
2426
datatypes::DataType,
2527
error::{ArrowError, Result},
2628
};
2729
use std::cmp::Ordering;
30+
use std::sync::Arc;
31+
32+
/// Returns an ArrayRef with substrings of all the elements in `array`.
33+
///
34+
/// # Arguments
35+
///
36+
/// * `start` - The start index of all substrings.
37+
/// If `start >= 0`, then count from the start of the string,
38+
/// otherwise count from the end of the string.
39+
///
40+
/// * `length` (optional) - The length of all substrings.
41+
/// If `length` is `None`, then the substring is from `start` to the end of the string.
42+
///
43+
/// Attention: Both `start` and `length` are counted by byte, not by char.
44+
///
45+
/// # Basic usage
46+
/// ```
47+
/// # use arrow::array::StringArray;
48+
/// # use arrow::compute::kernels::substring::substring;
49+
/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]);
50+
/// let result = substring(&array, 1, Some(4)).unwrap();
51+
/// let result = result.as_any().downcast_ref::<StringArray>().unwrap();
52+
/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")]));
53+
/// ```
54+
///
55+
/// # Errors
56+
/// - The function errors when the passed array is not a \[Large\]String array, \[Large\]Binary
57+
/// array, FixedSizeBinary array, or DictionaryArray with one of these as its value type.
58+
/// - The function errors if the offset of a substring in the input array is at an invalid char boundary (only for \[Large\]String arrays).
59+
///
60+
/// ## Example of requesting a substring that is not valid UTF-8
61+
/// ```
62+
/// # use arrow::array::StringArray;
63+
/// # use arrow::compute::kernels::substring::substring;
64+
/// let array = StringArray::from(vec![Some("E=mc²")]);
65+
/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string();
66+
/// assert!(error.contains("invalid utf-8 boundary"));
67+
/// ```
68+
pub fn substring(array: &dyn Array, start: i64, length: Option<u64>) -> Result<ArrayRef> {
69+
macro_rules! substring_dict {
70+
($kt: ident, $($t: ident: $gt: ident), *) => {
71+
match $kt.as_ref() {
72+
$(
73+
&DataType::$t => {
74+
let dict = array
75+
.as_any()
76+
.downcast_ref::<DictionaryArray<$gt>>()
77+
.unwrap_or_else(|| {
78+
panic!("Expect 'DictionaryArray<{}>' but got array of data type {:?}",
79+
stringify!($gt), array.data_type())
80+
});
81+
let values = substring(dict.values(), start, length)?;
82+
let result = DictionaryArray::try_new(dict.keys(), &values)?;
83+
Ok(Arc::new(result))
84+
},
85+
)*
86+
t => panic!("Unsupported dictionary key type: {}", t)
87+
}
88+
}
89+
}
90+
91+
match array.data_type() {
92+
DataType::Dictionary(kt, _) => {
93+
substring_dict!(
94+
kt,
95+
Int8: Int8Type,
96+
Int16: Int16Type,
97+
Int32: Int32Type,
98+
Int64: Int64Type,
99+
UInt8: UInt8Type,
100+
UInt16: UInt16Type,
101+
UInt32: UInt32Type,
102+
UInt64: UInt64Type
103+
)
104+
}
105+
DataType::LargeBinary => binary_substring(
106+
array
107+
.as_any()
108+
.downcast_ref::<LargeBinaryArray>()
109+
.expect("A large binary is expected"),
110+
start,
111+
length.map(|e| e as i64),
112+
),
113+
DataType::Binary => binary_substring(
114+
array
115+
.as_any()
116+
.downcast_ref::<BinaryArray>()
117+
.expect("A binary is expected"),
118+
start as i32,
119+
length.map(|e| e as i32),
120+
),
121+
DataType::FixedSizeBinary(old_len) => fixed_size_binary_substring(
122+
array
123+
.as_any()
124+
.downcast_ref::<FixedSizeBinaryArray>()
125+
.expect("a fixed size binary is expected"),
126+
*old_len,
127+
start as i32,
128+
length.map(|e| e as i32),
129+
),
130+
DataType::LargeUtf8 => utf8_substring(
131+
array
132+
.as_any()
133+
.downcast_ref::<LargeStringArray>()
134+
.expect("A large string is expected"),
135+
start,
136+
length.map(|e| e as i64),
137+
),
138+
DataType::Utf8 => utf8_substring(
139+
array
140+
.as_any()
141+
.downcast_ref::<StringArray>()
142+
.expect("A string is expected"),
143+
start as i32,
144+
length.map(|e| e as i32),
145+
),
146+
_ => Err(ArrowError::ComputeError(format!(
147+
"substring does not support type {:?}",
148+
array.data_type()
149+
))),
150+
}
151+
}
28152

29153
fn binary_substring<OffsetSize: OffsetSizeTrait>(
30154
array: &GenericBinaryArray<OffsetSize>,
@@ -215,94 +339,10 @@ fn utf8_substring<OffsetSize: OffsetSizeTrait>(
215339
Ok(make_array(data))
216340
}
217341

218-
/// Returns an ArrayRef with substrings of all the elements in `array`.
219-
///
220-
/// # Arguments
221-
///
222-
/// * `start` - The start index of all substrings.
223-
/// If `start >= 0`, then count from the start of the string,
224-
/// otherwise count from the end of the string.
225-
///
226-
/// * `length`(option) - The length of all substrings.
227-
/// If `length` is `None`, then the substring is from `start` to the end of the string.
228-
///
229-
/// Attention: Both `start` and `length` are counted by byte, not by char.
230-
///
231-
/// # Basic usage
232-
/// ```
233-
/// # use arrow::array::StringArray;
234-
/// # use arrow::compute::kernels::substring::substring;
235-
/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]);
236-
/// let result = substring(&array, 1, Some(4)).unwrap();
237-
/// let result = result.as_any().downcast_ref::<StringArray>().unwrap();
238-
/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")]));
239-
/// ```
240-
///
241-
/// # Error
242-
/// - The function errors when the passed array is not a \[Large\]String array or \[Large\]Binary array.
243-
/// - The function errors if the offset of a substring in the input array is at invalid char boundary (only for \[Large\]String array).
244-
///
245-
/// ## Example of trying to get an invalid utf-8 format substring
246-
/// ```
247-
/// # use arrow::array::StringArray;
248-
/// # use arrow::compute::kernels::substring::substring;
249-
/// let array = StringArray::from(vec![Some("E=mc²")]);
250-
/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string();
251-
/// assert!(error.contains("invalid utf-8 boundary"));
252-
/// ```
253-
pub fn substring(array: &dyn Array, start: i64, length: Option<u64>) -> Result<ArrayRef> {
254-
match array.data_type() {
255-
DataType::LargeBinary => binary_substring(
256-
array
257-
.as_any()
258-
.downcast_ref::<LargeBinaryArray>()
259-
.expect("A large binary is expected"),
260-
start,
261-
length.map(|e| e as i64),
262-
),
263-
DataType::Binary => binary_substring(
264-
array
265-
.as_any()
266-
.downcast_ref::<BinaryArray>()
267-
.expect("A binary is expected"),
268-
start as i32,
269-
length.map(|e| e as i32),
270-
),
271-
DataType::FixedSizeBinary(old_len) => fixed_size_binary_substring(
272-
array
273-
.as_any()
274-
.downcast_ref::<FixedSizeBinaryArray>()
275-
.expect("a fixed size binary is expected"),
276-
*old_len,
277-
start as i32,
278-
length.map(|e| e as i32),
279-
),
280-
DataType::LargeUtf8 => utf8_substring(
281-
array
282-
.as_any()
283-
.downcast_ref::<LargeStringArray>()
284-
.expect("A large string is expected"),
285-
start,
286-
length.map(|e| e as i64),
287-
),
288-
DataType::Utf8 => utf8_substring(
289-
array
290-
.as_any()
291-
.downcast_ref::<StringArray>()
292-
.expect("A string is expected"),
293-
start as i32,
294-
length.map(|e| e as i32),
295-
),
296-
_ => Err(ArrowError::ComputeError(format!(
297-
"substring does not support type {:?}",
298-
array.data_type()
299-
))),
300-
}
301-
}
302-
303342
#[cfg(test)]
304343
mod tests {
305344
use super::*;
345+
use crate::datatypes::*;
306346

307347
#[allow(clippy::type_complexity)]
308348
fn with_nulls_generic_binary<O: OffsetSizeTrait>() -> Result<()> {
@@ -954,6 +994,56 @@ mod tests {
954994
without_nulls_generic_string::<i64>()
955995
}
956996

997+
#[test]
998+
fn dictionary() -> Result<()> {
999+
_dictionary::<Int8Type>()?;
1000+
_dictionary::<Int16Type>()?;
1001+
_dictionary::<Int32Type>()?;
1002+
_dictionary::<Int64Type>()?;
1003+
_dictionary::<UInt8Type>()?;
1004+
_dictionary::<UInt16Type>()?;
1005+
_dictionary::<UInt32Type>()?;
1006+
_dictionary::<UInt64Type>()?;
1007+
Ok(())
1008+
}
1009+
1010+
fn _dictionary<K: ArrowDictionaryKeyType>() -> Result<()> {
1011+
const TOTAL: i32 = 100;
1012+
1013+
let v = ["aaa", "bbb", "ccc", "ddd", "eee"];
1014+
let data: Vec<Option<&str>> = (0..TOTAL)
1015+
.map(|n| {
1016+
let i = n % 5;
1017+
if i == 3 {
1018+
None
1019+
} else {
1020+
Some(v[i as usize])
1021+
}
1022+
})
1023+
.collect();
1024+
1025+
let dict_array: DictionaryArray<K> = data.clone().into_iter().collect();
1026+
1027+
let expected: Vec<Option<&str>> =
1028+
data.iter().map(|opt| opt.map(|s| &s[1..3])).collect();
1029+
1030+
let res = substring(&dict_array, 1, Some(2))?;
1031+
let actual = res.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
1032+
let actual: Vec<Option<&str>> = actual
1033+
.values()
1034+
.as_any()
1035+
.downcast_ref::<GenericStringArray<i32>>()
1036+
.unwrap()
1037+
.take_iter(actual.keys_iter())
1038+
.collect();
1039+
1040+
for i in 0..TOTAL as usize {
1041+
assert_eq!(expected[i], actual[i],);
1042+
}
1043+
1044+
Ok(())
1045+
}
1046+
9571047
#[test]
9581048
fn check_invalid_array_type() {
9591049
let array = Int32Array::from(vec![Some(1), Some(2), Some(3)]);

0 commit comments

Comments
 (0)