|
18 | 18 | //! Defines kernel to extract a substring of an Array |
19 | 19 | //! Supported array types: \[Large\]StringArray, \[Large\]BinaryArray, FixedSizeBinaryArray, and DictionaryArray whose values are one of these types |
20 | 20 |
|
| 21 | +use crate::array::DictionaryArray; |
21 | 22 | use crate::buffer::MutableBuffer; |
| 23 | +use crate::datatypes::*; |
22 | 24 | use crate::{array::*, buffer::Buffer}; |
23 | 25 | use crate::{ |
24 | 26 | datatypes::DataType, |
25 | 27 | error::{ArrowError, Result}, |
26 | 28 | }; |
27 | 29 | use std::cmp::Ordering; |
| 30 | +use std::sync::Arc; |
| 31 | + |
| 32 | +/// Returns an ArrayRef with substrings of all the elements in `array`. |
| 33 | +/// |
| 34 | +/// # Arguments |
| 35 | +/// |
| 36 | +/// * `start` - The start index of all substrings. |
| 37 | +/// If `start >= 0`, then count from the start of the string, |
| 38 | +/// otherwise count from the end of the string. |
| 39 | +/// |
| 40 | +/// * `length`(option) - The length of all substrings. |
| 41 | +/// If `length` is `None`, then the substring is from `start` to the end of the string. |
| 42 | +/// |
| 43 | +/// Attention: Both `start` and `length` are counted by byte, not by char. |
| 44 | +/// |
| 45 | +/// # Basic usage |
| 46 | +/// ``` |
| 47 | +/// # use arrow::array::StringArray; |
| 48 | +/// # use arrow::compute::kernels::substring::substring; |
| 49 | +/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]); |
| 50 | +/// let result = substring(&array, 1, Some(4)).unwrap(); |
| 51 | +/// let result = result.as_any().downcast_ref::<StringArray>().unwrap(); |
| 52 | +/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")])); |
| 53 | +/// ``` |
| 54 | +/// |
| 55 | +/// # Error |
| 56 | +/// - The function errors when the passed array is not a \[Large\]String array, \[Large\]Binary |
| 57 | +/// array, or DictionaryArray with \[Large\]String or \[Large\]Binary as its value type. |
| 58 | +/// - The function errors if the offset of a substring in the input array is at invalid char boundary (only for \[Large\]String array). |
| 59 | +/// |
| 60 | +/// ## Example of trying to get an invalid utf-8 format substring |
| 61 | +/// ``` |
| 62 | +/// # use arrow::array::StringArray; |
| 63 | +/// # use arrow::compute::kernels::substring::substring; |
| 64 | +/// let array = StringArray::from(vec![Some("E=mc²")]); |
| 65 | +/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string(); |
| 66 | +/// assert!(error.contains("invalid utf-8 boundary")); |
| 67 | +/// ``` |
| 68 | +pub fn substring(array: &dyn Array, start: i64, length: Option<u64>) -> Result<ArrayRef> { |
| 69 | + macro_rules! substring_dict { |
| 70 | + ($kt: ident, $($t: ident: $gt: ident), *) => { |
| 71 | + match $kt.as_ref() { |
| 72 | + $( |
| 73 | + &DataType::$t => { |
| 74 | + let dict = array |
| 75 | + .as_any() |
| 76 | + .downcast_ref::<DictionaryArray<$gt>>() |
| 77 | + .unwrap_or_else(|| { |
| 78 | + panic!("Expect 'DictionaryArray<{}>' but got array of data type {:?}", |
| 79 | + stringify!($gt), array.data_type()) |
| 80 | + }); |
| 81 | + let values = substring(dict.values(), start, length)?; |
| 82 | + let result = DictionaryArray::try_new(dict.keys(), &values)?; |
| 83 | + Ok(Arc::new(result)) |
| 84 | + }, |
| 85 | + )* |
| 86 | + t => panic!("Unsupported dictionary key type: {}", t) |
| 87 | + } |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + match array.data_type() { |
| 92 | + DataType::Dictionary(kt, _) => { |
| 93 | + substring_dict!( |
| 94 | + kt, |
| 95 | + Int8: Int8Type, |
| 96 | + Int16: Int16Type, |
| 97 | + Int32: Int32Type, |
| 98 | + Int64: Int64Type, |
| 99 | + UInt8: UInt8Type, |
| 100 | + UInt16: UInt16Type, |
| 101 | + UInt32: UInt32Type, |
| 102 | + UInt64: UInt64Type |
| 103 | + ) |
| 104 | + } |
| 105 | + DataType::LargeBinary => binary_substring( |
| 106 | + array |
| 107 | + .as_any() |
| 108 | + .downcast_ref::<LargeBinaryArray>() |
| 109 | + .expect("A large binary is expected"), |
| 110 | + start, |
| 111 | + length.map(|e| e as i64), |
| 112 | + ), |
| 113 | + DataType::Binary => binary_substring( |
| 114 | + array |
| 115 | + .as_any() |
| 116 | + .downcast_ref::<BinaryArray>() |
| 117 | + .expect("A binary is expected"), |
| 118 | + start as i32, |
| 119 | + length.map(|e| e as i32), |
| 120 | + ), |
| 121 | + DataType::FixedSizeBinary(old_len) => fixed_size_binary_substring( |
| 122 | + array |
| 123 | + .as_any() |
| 124 | + .downcast_ref::<FixedSizeBinaryArray>() |
| 125 | + .expect("a fixed size binary is expected"), |
| 126 | + *old_len, |
| 127 | + start as i32, |
| 128 | + length.map(|e| e as i32), |
| 129 | + ), |
| 130 | + DataType::LargeUtf8 => utf8_substring( |
| 131 | + array |
| 132 | + .as_any() |
| 133 | + .downcast_ref::<LargeStringArray>() |
| 134 | + .expect("A large string is expected"), |
| 135 | + start, |
| 136 | + length.map(|e| e as i64), |
| 137 | + ), |
| 138 | + DataType::Utf8 => utf8_substring( |
| 139 | + array |
| 140 | + .as_any() |
| 141 | + .downcast_ref::<StringArray>() |
| 142 | + .expect("A string is expected"), |
| 143 | + start as i32, |
| 144 | + length.map(|e| e as i32), |
| 145 | + ), |
| 146 | + _ => Err(ArrowError::ComputeError(format!( |
| 147 | + "substring does not support type {:?}", |
| 148 | + array.data_type() |
| 149 | + ))), |
| 150 | + } |
| 151 | +} |
28 | 152 |
|
29 | 153 | fn binary_substring<OffsetSize: OffsetSizeTrait>( |
30 | 154 | array: &GenericBinaryArray<OffsetSize>, |
@@ -215,94 +339,10 @@ fn utf8_substring<OffsetSize: OffsetSizeTrait>( |
215 | 339 | Ok(make_array(data)) |
216 | 340 | } |
217 | 341 |
|
218 | | -/// Returns an ArrayRef with substrings of all the elements in `array`. |
219 | | -/// |
220 | | -/// # Arguments |
221 | | -/// |
222 | | -/// * `start` - The start index of all substrings. |
223 | | -/// If `start >= 0`, then count from the start of the string, |
224 | | -/// otherwise count from the end of the string. |
225 | | -/// |
226 | | -/// * `length`(option) - The length of all substrings. |
227 | | -/// If `length` is `None`, then the substring is from `start` to the end of the string. |
228 | | -/// |
229 | | -/// Attention: Both `start` and `length` are counted by byte, not by char. |
230 | | -/// |
231 | | -/// # Basic usage |
232 | | -/// ``` |
233 | | -/// # use arrow::array::StringArray; |
234 | | -/// # use arrow::compute::kernels::substring::substring; |
235 | | -/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]); |
236 | | -/// let result = substring(&array, 1, Some(4)).unwrap(); |
237 | | -/// let result = result.as_any().downcast_ref::<StringArray>().unwrap(); |
238 | | -/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")])); |
239 | | -/// ``` |
240 | | -/// |
241 | | -/// # Error |
242 | | -/// - The function errors when the passed array is not a \[Large\]String array or \[Large\]Binary array. |
243 | | -/// - The function errors if the offset of a substring in the input array is at invalid char boundary (only for \[Large\]String array). |
244 | | -/// |
245 | | -/// ## Example of trying to get an invalid utf-8 format substring |
246 | | -/// ``` |
247 | | -/// # use arrow::array::StringArray; |
248 | | -/// # use arrow::compute::kernels::substring::substring; |
249 | | -/// let array = StringArray::from(vec![Some("E=mc²")]); |
250 | | -/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string(); |
251 | | -/// assert!(error.contains("invalid utf-8 boundary")); |
252 | | -/// ``` |
253 | | -pub fn substring(array: &dyn Array, start: i64, length: Option<u64>) -> Result<ArrayRef> { |
254 | | - match array.data_type() { |
255 | | - DataType::LargeBinary => binary_substring( |
256 | | - array |
257 | | - .as_any() |
258 | | - .downcast_ref::<LargeBinaryArray>() |
259 | | - .expect("A large binary is expected"), |
260 | | - start, |
261 | | - length.map(|e| e as i64), |
262 | | - ), |
263 | | - DataType::Binary => binary_substring( |
264 | | - array |
265 | | - .as_any() |
266 | | - .downcast_ref::<BinaryArray>() |
267 | | - .expect("A binary is expected"), |
268 | | - start as i32, |
269 | | - length.map(|e| e as i32), |
270 | | - ), |
271 | | - DataType::FixedSizeBinary(old_len) => fixed_size_binary_substring( |
272 | | - array |
273 | | - .as_any() |
274 | | - .downcast_ref::<FixedSizeBinaryArray>() |
275 | | - .expect("a fixed size binary is expected"), |
276 | | - *old_len, |
277 | | - start as i32, |
278 | | - length.map(|e| e as i32), |
279 | | - ), |
280 | | - DataType::LargeUtf8 => utf8_substring( |
281 | | - array |
282 | | - .as_any() |
283 | | - .downcast_ref::<LargeStringArray>() |
284 | | - .expect("A large string is expected"), |
285 | | - start, |
286 | | - length.map(|e| e as i64), |
287 | | - ), |
288 | | - DataType::Utf8 => utf8_substring( |
289 | | - array |
290 | | - .as_any() |
291 | | - .downcast_ref::<StringArray>() |
292 | | - .expect("A string is expected"), |
293 | | - start as i32, |
294 | | - length.map(|e| e as i32), |
295 | | - ), |
296 | | - _ => Err(ArrowError::ComputeError(format!( |
297 | | - "substring does not support type {:?}", |
298 | | - array.data_type() |
299 | | - ))), |
300 | | - } |
301 | | -} |
302 | | - |
303 | 342 | #[cfg(test)] |
304 | 343 | mod tests { |
305 | 344 | use super::*; |
| 345 | + use crate::datatypes::*; |
306 | 346 |
|
307 | 347 | #[allow(clippy::type_complexity)] |
308 | 348 | fn with_nulls_generic_binary<O: OffsetSizeTrait>() -> Result<()> { |
@@ -954,6 +994,56 @@ mod tests { |
954 | 994 | without_nulls_generic_string::<i64>() |
955 | 995 | } |
956 | 996 |
|
| 997 | + #[test] |
| 998 | + fn dictionary() -> Result<()> { |
| 999 | + _dictionary::<Int8Type>()?; |
| 1000 | + _dictionary::<Int16Type>()?; |
| 1001 | + _dictionary::<Int32Type>()?; |
| 1002 | + _dictionary::<Int64Type>()?; |
| 1003 | + _dictionary::<UInt8Type>()?; |
| 1004 | + _dictionary::<UInt16Type>()?; |
| 1005 | + _dictionary::<UInt32Type>()?; |
| 1006 | + _dictionary::<UInt64Type>()?; |
| 1007 | + Ok(()) |
| 1008 | + } |
| 1009 | + |
| 1010 | + fn _dictionary<K: ArrowDictionaryKeyType>() -> Result<()> { |
| 1011 | + const TOTAL: i32 = 100; |
| 1012 | + |
| 1013 | + let v = ["aaa", "bbb", "ccc", "ddd", "eee"]; |
| 1014 | + let data: Vec<Option<&str>> = (0..TOTAL) |
| 1015 | + .map(|n| { |
| 1016 | + let i = n % 5; |
| 1017 | + if i == 3 { |
| 1018 | + None |
| 1019 | + } else { |
| 1020 | + Some(v[i as usize]) |
| 1021 | + } |
| 1022 | + }) |
| 1023 | + .collect(); |
| 1024 | + |
| 1025 | + let dict_array: DictionaryArray<K> = data.clone().into_iter().collect(); |
| 1026 | + |
| 1027 | + let expected: Vec<Option<&str>> = |
| 1028 | + data.iter().map(|opt| opt.map(|s| &s[1..3])).collect(); |
| 1029 | + |
| 1030 | + let res = substring(&dict_array, 1, Some(2))?; |
| 1031 | + let actual = res.as_any().downcast_ref::<DictionaryArray<K>>().unwrap(); |
| 1032 | + let actual: Vec<Option<&str>> = actual |
| 1033 | + .values() |
| 1034 | + .as_any() |
| 1035 | + .downcast_ref::<GenericStringArray<i32>>() |
| 1036 | + .unwrap() |
| 1037 | + .take_iter(actual.keys_iter()) |
| 1038 | + .collect(); |
| 1039 | + |
| 1040 | + for i in 0..TOTAL as usize { |
| 1041 | + assert_eq!(expected[i], actual[i],); |
| 1042 | + } |
| 1043 | + |
| 1044 | + Ok(()) |
| 1045 | + } |
| 1046 | + |
957 | 1047 | #[test] |
958 | 1048 | fn check_invalid_array_type() { |
959 | 1049 | let array = Int32Array::from(vec![Some(1), Some(2), Some(3)]); |
|
0 commit comments