Skip to content

Commit 8c9e567

Browse files
Optimize performance of substr_index and add tests (#9973)
* Optimize performance of substr_index
1 parent 86ad8a5 commit 8c9e567

File tree

2 files changed

+143
-21
lines changed

2 files changed

+143
-21
lines changed

datafusion/functions/src/unicode/substrindex.rs

Lines changed: 133 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
21+
use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder};
2222
use arrow::datatypes::DataType;
2323

2424
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
@@ -101,38 +101,151 @@ pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
101101
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
102102
let count_array = as_int64_array(&args[2])?;
103103

104-
let result = string_array
104+
let mut builder = StringBuilder::new();
105+
string_array
105106
.iter()
106107
.zip(delimiter_array.iter())
107108
.zip(count_array.iter())
108-
.map(|((string, delimiter), n)| match (string, delimiter, n) {
109+
.for_each(|((string, delimiter), n)| match (string, delimiter, n) {
109110
(Some(string), Some(delimiter), Some(n)) => {
110111
// In MySQL, these cases will return an empty string.
111112
if n == 0 || string.is_empty() || delimiter.is_empty() {
112-
return Some(String::new());
113+
builder.append_value("");
114+
return;
113115
}
114116

115-
let splitted: Box<dyn Iterator<Item = _>> = if n > 0 {
116-
Box::new(string.split(delimiter))
117+
let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
118+
let length = if n > 0 {
119+
let splitted = string.split(delimiter);
120+
splitted
121+
.take(occurrences)
122+
.map(|s| s.len() + delimiter.len())
123+
.sum::<usize>()
124+
- delimiter.len()
117125
} else {
118-
Box::new(string.rsplit(delimiter))
126+
let splitted = string.rsplit(delimiter);
127+
splitted
128+
.take(occurrences)
129+
.map(|s| s.len() + delimiter.len())
130+
.sum::<usize>()
131+
- delimiter.len()
119132
};
120-
let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
121-
// The length of the substring covered by substr_index.
122-
let length = splitted
123-
.take(occurrences) // at least 1 element, since n != 0
124-
.map(|s| s.len() + delimiter.len())
125-
.sum::<usize>()
126-
- delimiter.len();
127133
if n > 0 {
128-
Some(string[..length].to_owned())
134+
match string.get(..length) {
135+
Some(substring) => builder.append_value(substring),
136+
None => builder.append_null(),
137+
}
129138
} else {
130-
Some(string[string.len() - length..].to_owned())
139+
match string.get(string.len().saturating_sub(length)..) {
140+
Some(substring) => builder.append_value(substring),
141+
None => builder.append_null(),
142+
}
131143
}
132144
}
133-
_ => None,
134-
})
135-
.collect::<GenericStringArray<T>>();
145+
_ => builder.append_null(),
146+
});
147+
148+
Ok(Arc::new(builder.finish()) as ArrayRef)
149+
}
136150

137-
Ok(Arc::new(result) as ArrayRef)
151+
// Unit tests for `SubstrIndexFunc`, driven through the project's
// `test_function!` macro with scalar (string, delimiter, count) arguments.
// NOTE(review): the trailing `&str, Utf8, StringArray` macro arguments are
// presumably the expected value type, the expected return `DataType`, and the
// array type to downcast to; confirm against `crate::utils::test`.
#[cfg(test)]
152+
mod tests {
153+
use arrow::array::{Array, StringArray};
154+
use arrow::datatypes::DataType::Utf8;
155+
156+
use datafusion_common::{Result, ScalarValue};
157+
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
158+
159+
use crate::unicode::substrindex::SubstrIndexFunc;
160+
use crate::utils::test::test_function;
161+
162+
#[test]
163+
fn test_functions() -> Result<()> {
// count = 1: expect the text before the first "." ("www").
164+
test_function!(
165+
SubstrIndexFunc::new(),
166+
&[
167+
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
168+
ColumnarValue::Scalar(ScalarValue::from(".")),
169+
ColumnarValue::Scalar(ScalarValue::from(1i64)),
170+
],
171+
Ok(Some("www")),
172+
&str,
173+
Utf8,
174+
StringArray
175+
);
// count = 2: expect the text before the second "." ("www.apache").
176+
test_function!(
177+
SubstrIndexFunc::new(),
178+
&[
179+
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
180+
ColumnarValue::Scalar(ScalarValue::from(".")),
181+
ColumnarValue::Scalar(ScalarValue::from(2i64)),
182+
],
183+
Ok(Some("www.apache")),
184+
&str,
185+
Utf8,
186+
StringArray
187+
);
// count = -2: negative counts work from the right; expect the text after
// the second "." from the end ("apache.org").
188+
test_function!(
189+
SubstrIndexFunc::new(),
190+
&[
191+
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
192+
ColumnarValue::Scalar(ScalarValue::from(".")),
193+
ColumnarValue::Scalar(ScalarValue::from(-2i64)),
194+
],
195+
Ok(Some("apache.org")),
196+
&str,
197+
Utf8,
198+
StringArray
199+
);
// count = -1: expect the text after the last "." ("org").
200+
test_function!(
201+
SubstrIndexFunc::new(),
202+
&[
203+
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
204+
ColumnarValue::Scalar(ScalarValue::from(".")),
205+
ColumnarValue::Scalar(ScalarValue::from(-1i64)),
206+
],
207+
Ok(Some("org")),
208+
&str,
209+
Utf8,
210+
StringArray
211+
);
// count = 0: expect an empty string (not NULL).
212+
test_function!(
213+
SubstrIndexFunc::new(),
214+
&[
215+
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
216+
ColumnarValue::Scalar(ScalarValue::from(".")),
217+
ColumnarValue::Scalar(ScalarValue::from(0i64)),
218+
],
219+
Ok(Some("")),
220+
&str,
221+
Utf8,
222+
StringArray
223+
);
// Empty input string: expect an empty string.
224+
test_function!(
225+
SubstrIndexFunc::new(),
226+
&[
227+
ColumnarValue::Scalar(ScalarValue::from("")),
228+
ColumnarValue::Scalar(ScalarValue::from(".")),
229+
ColumnarValue::Scalar(ScalarValue::from(1i64)),
230+
],
231+
Ok(Some("")),
232+
&str,
233+
Utf8,
234+
StringArray
235+
);
// Empty delimiter: expect an empty string.
236+
test_function!(
237+
SubstrIndexFunc::new(),
238+
&[
239+
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
240+
ColumnarValue::Scalar(ScalarValue::from("")),
241+
ColumnarValue::Scalar(ScalarValue::from(1i64)),
242+
],
243+
Ok(Some("")),
244+
&str,
245+
Utf8,
246+
StringArray
247+
);
248+
249+
Ok(())
250+
}
138251
}

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,8 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM
940940
(VALUES
941941
ROW('arrow.apache.org'),
942942
ROW('.'),
943-
ROW('...')
943+
ROW('...'),
944+
ROW(NULL)
944945
) AS strings(str),
945946
(VALUES
946947
ROW(1),
@@ -954,6 +955,14 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM
954955
) AS occurrences(n)
955956
ORDER BY str DESC, n;
956957
----
958+
NULL -100 NULL
959+
NULL -3 NULL
960+
NULL -2 NULL
961+
NULL -1 NULL
962+
NULL 1 NULL
963+
NULL 2 NULL
964+
NULL 3 NULL
965+
NULL 100 NULL
957966
arrow.apache.org -100 arrow.apache.org
958967
arrow.apache.org -3 arrow.apache.org
959968
arrow.apache.org -2 apache.org

0 commit comments

Comments
 (0)