|
18 | 18 | use std::any::Any; |
19 | 19 | use std::sync::Arc; |
20 | 20 |
|
21 | | -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; |
| 21 | +use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder}; |
22 | 22 | use arrow::datatypes::DataType; |
23 | 23 |
|
24 | 24 | use datafusion_common::cast::{as_generic_string_array, as_int64_array}; |
@@ -101,38 +101,151 @@ pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { |
101 | 101 | let delimiter_array = as_generic_string_array::<T>(&args[1])?; |
102 | 102 | let count_array = as_int64_array(&args[2])?; |
103 | 103 |
|
104 | | - let result = string_array |
| 104 | + let mut builder = StringBuilder::new(); |
| 105 | + string_array |
105 | 106 | .iter() |
106 | 107 | .zip(delimiter_array.iter()) |
107 | 108 | .zip(count_array.iter()) |
108 | | - .map(|((string, delimiter), n)| match (string, delimiter, n) { |
| 109 | + .for_each(|((string, delimiter), n)| match (string, delimiter, n) { |
109 | 110 | (Some(string), Some(delimiter), Some(n)) => { |
110 | 111 | // In MySQL, these cases will return an empty string. |
111 | 112 | if n == 0 || string.is_empty() || delimiter.is_empty() { |
112 | | - return Some(String::new()); |
| 113 | + builder.append_value(""); |
| 114 | + return; |
113 | 115 | } |
114 | 116 |
|
115 | | - let splitted: Box<dyn Iterator<Item = _>> = if n > 0 { |
116 | | - Box::new(string.split(delimiter)) |
| 117 | + let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); |
| 118 | + let length = if n > 0 { |
| 119 | + let splitted = string.split(delimiter); |
| 120 | + splitted |
| 121 | + .take(occurrences) |
| 122 | + .map(|s| s.len() + delimiter.len()) |
| 123 | + .sum::<usize>() |
| 124 | + - delimiter.len() |
117 | 125 | } else { |
118 | | - Box::new(string.rsplit(delimiter)) |
| 126 | + let splitted = string.rsplit(delimiter); |
| 127 | + splitted |
| 128 | + .take(occurrences) |
| 129 | + .map(|s| s.len() + delimiter.len()) |
| 130 | + .sum::<usize>() |
| 131 | + - delimiter.len() |
119 | 132 | }; |
120 | | - let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); |
121 | | - // The length of the substring covered by substr_index. |
122 | | - let length = splitted |
123 | | - .take(occurrences) // at least 1 element, since n != 0 |
124 | | - .map(|s| s.len() + delimiter.len()) |
125 | | - .sum::<usize>() |
126 | | - - delimiter.len(); |
127 | 133 | if n > 0 { |
128 | | - Some(string[..length].to_owned()) |
| 134 | + match string.get(..length) { |
| 135 | + Some(substring) => builder.append_value(substring), |
| 136 | + None => builder.append_null(), |
| 137 | + } |
129 | 138 | } else { |
130 | | - Some(string[string.len() - length..].to_owned()) |
| 139 | + match string.get(string.len().saturating_sub(length)..) { |
| 140 | + Some(substring) => builder.append_value(substring), |
| 141 | + None => builder.append_null(), |
| 142 | + } |
131 | 143 | } |
132 | 144 | } |
133 | | - _ => None, |
134 | | - }) |
135 | | - .collect::<GenericStringArray<T>>(); |
| 145 | + _ => builder.append_null(), |
| 146 | + }); |
| 147 | + |
| 148 | + Ok(Arc::new(builder.finish()) as ArrayRef) |
| 149 | +} |
136 | 150 |
|
137 | | - Ok(Arc::new(result) as ArrayRef) |
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::substrindex::SubstrIndexFunc;
    use crate::utils::test::test_function;

    /// Wraps a string literal as a scalar function argument.
    fn text(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::from(value))
    }

    /// Wraps an `i64` occurrence count as a scalar function argument.
    fn count(value: i64) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::from(value))
    }

    /// Exercises `SubstrIndexFunc` with scalar arguments: positive counts,
    /// negative counts, and the three inputs expected to yield an empty string
    /// (zero count, empty source string, empty delimiter).
    #[test]
    fn test_functions() -> Result<()> {
        // Positive count of 1: expect the text before the first ".".
        test_function!(
            SubstrIndexFunc::new(),
            &[text("www.apache.org"), text("."), count(1)],
            Ok(Some("www")),
            &str,
            Utf8,
            StringArray
        );

        // Positive count of 2: expect the text before the second ".".
        test_function!(
            SubstrIndexFunc::new(),
            &[text("www.apache.org"), text("."), count(2)],
            Ok(Some("www.apache")),
            &str,
            Utf8,
            StringArray
        );

        // Count of -2: expect the text after the second "." from the right.
        test_function!(
            SubstrIndexFunc::new(),
            &[text("www.apache.org"), text("."), count(-2)],
            Ok(Some("apache.org")),
            &str,
            Utf8,
            StringArray
        );

        // Count of -1: expect the text after the last ".".
        test_function!(
            SubstrIndexFunc::new(),
            &[text("www.apache.org"), text("."), count(-1)],
            Ok(Some("org")),
            &str,
            Utf8,
            StringArray
        );

        // A zero count is expected to produce an empty string.
        test_function!(
            SubstrIndexFunc::new(),
            &[text("www.apache.org"), text("."), count(0)],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        // An empty source string is expected to produce an empty string.
        test_function!(
            SubstrIndexFunc::new(),
            &[text(""), text("."), count(1)],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        // An empty delimiter is expected to produce an empty string.
        test_function!(
            SubstrIndexFunc::new(),
            &[text("www.apache.org"), text(""), count(1)],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}
0 commit comments