Skip to content

Commit 15c19c3

Browse files
authored
Optimize regex_replace for scalar patterns (apache#3614)
* Optimize `regex_replace` for scalar patterns * Change the hot-path on `regexp_replace` to only variadic source (#2)
1 parent ea3dbb6 commit 15c19c3

File tree

2 files changed

+282
-20
lines changed

2 files changed

+282
-20
lines changed

datafusion/physical-expr/src/functions.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -500,20 +500,22 @@ pub fn create_physical_fun(
500500
BuiltinScalarFunction::RegexpReplace => {
501501
Arc::new(|args| match args[0].data_type() {
502502
DataType::Utf8 => {
503-
let func = invoke_if_regex_expressions_feature_flag!(
504-
regexp_replace,
503+
let specializer_func = invoke_if_regex_expressions_feature_flag!(
504+
specialize_regexp_replace,
505505
i32,
506506
"regexp_replace"
507507
);
508-
make_scalar_function(func)(args)
508+
let func = specializer_func(args)?;
509+
func(args)
509510
}
510511
DataType::LargeUtf8 => {
511-
let func = invoke_if_regex_expressions_feature_flag!(
512-
regexp_replace,
512+
let specializer_func = invoke_if_regex_expressions_feature_flag!(
513+
specialize_regexp_replace,
513514
i64,
514515
"regexp_replace"
515516
);
516-
make_scalar_function(func)(args)
517+
let func = specializer_func(args)?;
518+
func(args)
517519
}
518520
other => Err(DataFusionError::Internal(format!(
519521
"Unsupported data type {:?} for function regexp_replace",

datafusion/physical-expr/src/regex_expressions.rs

Lines changed: 274 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,32 @@
2121

2222
//! Regex expressions
2323
24-
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
24+
use arrow::array::{
25+
new_null_array, Array, ArrayRef, GenericStringArray, OffsetSizeTrait,
26+
};
2527
use arrow::compute;
2628
use datafusion_common::{DataFusionError, Result};
29+
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
2730
use hashbrown::HashMap;
2831
use lazy_static::lazy_static;
2932
use regex::Regex;
3033
use std::any::type_name;
3134
use std::sync::Arc;
3235

33-
macro_rules! downcast_string_arg {
36+
use crate::functions::make_scalar_function;
37+
38+
macro_rules! fetch_string_arg {
39+
($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{
40+
let array = downcast_string_array_arg!($ARG, $NAME, $T);
41+
if array.is_null(0) {
42+
return $EARLY_ABORT(array);
43+
} else {
44+
array.value(0)
45+
}
46+
}};
47+
}
48+
49+
macro_rules! downcast_string_array_arg {
3450
($ARG:expr, $NAME:expr, $T:ident) => {{
3551
$ARG.as_any()
3652
.downcast_ref::<GenericStringArray<T>>()
@@ -48,14 +64,14 @@ macro_rules! downcast_string_arg {
4864
pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
4965
match args.len() {
5066
2 => {
51-
let values = downcast_string_arg!(args[0], "string", T);
52-
let regex = downcast_string_arg!(args[1], "pattern", T);
67+
let values = downcast_string_array_arg!(args[0], "string", T);
68+
let regex = downcast_string_array_arg!(args[1], "pattern", T);
5369
compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError)
5470
}
5571
3 => {
56-
let values = downcast_string_arg!(args[0], "string", T);
57-
let regex = downcast_string_arg!(args[1], "pattern", T);
58-
let flags = Some(downcast_string_arg!(args[2], "flags", T));
72+
let values = downcast_string_array_arg!(args[0], "string", T);
73+
let regex = downcast_string_array_arg!(args[1], "pattern", T);
74+
let flags = Some(downcast_string_array_arg!(args[2], "flags", T));
5975
compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError)
6076
}
6177
other => Err(DataFusionError::Internal(format!(
@@ -80,14 +96,17 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
8096
///
8197
/// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'`
8298
pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
99+
// Default implementation for regexp_replace, assumes all args are arrays
100+
// and args is a sequence of 3 or 4 elements.
101+
83102
// creating Regex is expensive so create hashmap for memoization
84103
let mut patterns: HashMap<String, Regex> = HashMap::new();
85104

86105
match args.len() {
87106
3 => {
88-
let string_array = downcast_string_arg!(args[0], "string", T);
89-
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
90-
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
107+
let string_array = downcast_string_array_arg!(args[0], "string", T);
108+
let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
109+
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
91110

92111
let result = string_array
93112
.iter()
@@ -120,10 +139,10 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
120139
Ok(Arc::new(result) as ArrayRef)
121140
}
122141
4 => {
123-
let string_array = downcast_string_arg!(args[0], "string", T);
124-
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
125-
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
126-
let flags_array = downcast_string_arg!(args[3], "flags", T);
142+
let string_array = downcast_string_array_arg!(args[0], "string", T);
143+
let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
144+
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
145+
let flags_array = downcast_string_array_arg!(args[3], "flags", T);
127146

128147
let result = string_array
129148
.iter()
@@ -178,10 +197,125 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
178197
}
179198
}
180199

200+
fn _regexp_replace_early_abort<T: OffsetSizeTrait>(
201+
input_array: &GenericStringArray<T>,
202+
) -> Result<ArrayRef> {
203+
// Mimicking the existing behavior of regexp_replace, if any of the scalar arguments
204+
// are actually null, then the result will be an array of the same size but with nulls.
205+
Ok(new_null_array(input_array.data_type(), input_array.len()))
206+
}
207+
208+
/// Special-cased regexp_replace implementation for the scenario where
209+
/// the pattern, replacement and flags are static (arrays that are derived
210+
/// from scalars). This means we can skip the regex caching system and basically
211+
/// hold a single Regex object for the replace operation. This also speeds
212+
/// up the pre-processing time of the replacement string, since it only
213+
/// needs to be processed once.
214+
fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
215+
args: &[ArrayRef],
216+
) -> Result<ArrayRef> {
217+
let string_array = downcast_string_array_arg!(args[0], "string", T);
218+
let pattern = fetch_string_arg!(args[1], "pattern", T, _regexp_replace_early_abort);
219+
let replacement =
220+
fetch_string_arg!(args[2], "replacement", T, _regexp_replace_early_abort);
221+
let flags = match args.len() {
222+
3 => None,
223+
4 => Some(fetch_string_arg!(args[3], "flags", T, _regexp_replace_early_abort)),
224+
other => {
225+
return Err(DataFusionError::Internal(format!(
226+
"regexp_replace was called with {} arguments. It requires at least 3 and at most 4.",
227+
other
228+
)))
229+
}
230+
};
231+
232+
// Embed the flag (if it exists) into the pattern. Limit will determine
233+
// whether this is a global match (as in replace all) or just a single
234+
// replace operation.
235+
let (pattern, limit) = match flags {
236+
Some("g") => (pattern.to_string(), 0),
237+
Some(flags) => (
238+
format!("(?{}){}", flags.to_string().replace('g', ""), pattern),
239+
!flags.contains('g') as usize,
240+
),
241+
None => (pattern.to_string(), 1),
242+
};
243+
244+
let re = Regex::new(&pattern)
245+
.map_err(|err| DataFusionError::Execution(err.to_string()))?;
246+
247+
// Replaces the posix groups in the replacement string
248+
// with rust ones.
249+
let replacement = regex_replace_posix_groups(replacement);
250+
251+
let result = string_array
252+
.iter()
253+
.map(|string| {
254+
string.map(|string| re.replacen(string, limit, replacement.as_str()))
255+
})
256+
.collect::<GenericStringArray<T>>();
257+
Ok(Arc::new(result) as ArrayRef)
258+
}
259+
260+
/// Determine which implementation of the regexp_replace to use based
261+
/// on the given set of arguments.
262+
pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
263+
args: &[ColumnarValue],
264+
) -> Result<ScalarFunctionImplementation> {
265+
// This will serve as a dispatch table where we can
266+
// leverage it in order to determine whether the scalarity
267+
// of the given set of arguments fits a better specialized
268+
// function.
269+
let (is_source_scalar, is_pattern_scalar, is_replacement_scalar, is_flags_scalar) = (
270+
matches!(args[0], ColumnarValue::Scalar(_)),
271+
matches!(args[1], ColumnarValue::Scalar(_)),
272+
matches!(args[2], ColumnarValue::Scalar(_)),
273+
// The fourth argument (flags) is optional; so in the event that
274+
// it is not available, we'll claim that it is scalar.
275+
matches!(args.get(3), Some(ColumnarValue::Scalar(_)) | None),
276+
);
277+
278+
match (
279+
is_source_scalar,
280+
is_pattern_scalar,
281+
is_replacement_scalar,
282+
is_flags_scalar,
283+
) {
284+
// This represents a very hot path for the case where there is
285+
// a single pattern that is being matched against and a single replacement.
286+
// This is extremely important to specialize on since it removes the overhead
287+
// of DF's in-house regex pattern cache (since there will be at most a single
288+
// pattern) and the pre-processing of the same replacement pattern at each
289+
// query.
290+
//
291+
// The flags needs to be a scalar as well since each pattern is actually
292+
// constructed with the flags embedded into the pattern itself. This means
293+
// even if the pattern itself is scalar, if the flags are an array then
294+
// we will create many regexes and it is best to use the implementation
295+
// that caches it. If there are no flags, we can simply ignore it here,
296+
// and let the specialized function handle it.
297+
(_, true, true, true) => {
298+
// We still don't know the scalarity of source, so we need the adapter
299+
// even if it will do some extra work for the pattern and the flags.
300+
//
301+
// TODO: maybe we need a way of telling the adapter on which arguments
302+
// it can skip filling (so that we won't create N - 1 redundant cols).
303+
Ok(make_scalar_function(
304+
_regexp_replace_static_pattern_replace::<T>,
305+
))
306+
}
307+
308+
// If there are no specialized implementations, we'll fall back to the
309+
// generic implementation.
310+
(_, _, _, _) => Ok(make_scalar_function(regexp_replace::<T>)),
311+
}
312+
}
313+
181314
#[cfg(test)]
182315
mod tests {
183316
use super::*;
184317
use arrow::array::*;
318+
use datafusion_common::ScalarValue;
185319

186320
#[test]
187321
fn test_case_sensitive_regexp_match() {
@@ -231,4 +365,130 @@ mod tests {
231365

232366
assert_eq!(re.as_ref(), &expected);
233367
}
368+
369+
#[test]
370+
fn test_static_pattern_regexp_replace() {
371+
let values = StringArray::from(vec!["abc"; 5]);
372+
let patterns = StringArray::from(vec!["b"; 5]);
373+
let replacements = StringArray::from(vec!["foo"; 5]);
374+
let expected = StringArray::from(vec!["afooc"; 5]);
375+
376+
let re = _regexp_replace_static_pattern_replace::<i32>(&[
377+
Arc::new(values),
378+
Arc::new(patterns),
379+
Arc::new(replacements),
380+
])
381+
.unwrap();
382+
383+
assert_eq!(re.as_ref(), &expected);
384+
}
385+
386+
#[test]
387+
fn test_static_pattern_regexp_replace_with_flags() {
388+
let values = StringArray::from(vec!["abc", "ABC", "aBc", "AbC", "aBC"]);
389+
let patterns = StringArray::from(vec!["b"; 5]);
390+
let replacements = StringArray::from(vec!["foo"; 5]);
391+
let flags = StringArray::from(vec!["i"; 5]);
392+
let expected =
393+
StringArray::from(vec!["afooc", "AfooC", "afooc", "AfooC", "afooC"]);
394+
395+
let re = _regexp_replace_static_pattern_replace::<i32>(&[
396+
Arc::new(values),
397+
Arc::new(patterns),
398+
Arc::new(replacements),
399+
Arc::new(flags),
400+
])
401+
.unwrap();
402+
403+
assert_eq!(re.as_ref(), &expected);
404+
}
405+
406+
#[test]
407+
fn test_static_pattern_regexp_replace_early_abort() {
408+
let values = StringArray::from(vec!["abc"; 5]);
409+
let patterns = StringArray::from(vec![None; 5]);
410+
let replacements = StringArray::from(vec!["foo"; 5]);
411+
let expected = StringArray::from(vec![None; 5]);
412+
413+
let re = _regexp_replace_static_pattern_replace::<i32>(&[
414+
Arc::new(values),
415+
Arc::new(patterns),
416+
Arc::new(replacements),
417+
])
418+
.unwrap();
419+
420+
assert_eq!(re.as_ref(), &expected);
421+
}
422+
423+
#[test]
424+
fn test_static_pattern_regexp_replace_early_abort_flags() {
425+
let values = StringArray::from(vec!["abc"; 5]);
426+
let patterns = StringArray::from(vec!["a"; 5]);
427+
let replacements = StringArray::from(vec!["foo"; 5]);
428+
let flags = StringArray::from(vec![None; 5]);
429+
let expected = StringArray::from(vec![None; 5]);
430+
431+
let re = _regexp_replace_static_pattern_replace::<i32>(&[
432+
Arc::new(values),
433+
Arc::new(patterns),
434+
Arc::new(replacements),
435+
Arc::new(flags),
436+
])
437+
.unwrap();
438+
439+
assert_eq!(re.as_ref(), &expected);
440+
}
441+
442+
#[test]
443+
fn test_static_pattern_regexp_replace_pattern_error() {
444+
let values = StringArray::from(vec!["abc"; 5]);
445+
// Deliberately using an invalid pattern to see how the single pattern
446+
// error is propagated on regexp_replace.
447+
let patterns = StringArray::from(vec!["["; 5]);
448+
let replacements = StringArray::from(vec!["foo"; 5]);
449+
450+
let re = _regexp_replace_static_pattern_replace::<i32>(&[
451+
Arc::new(values),
452+
Arc::new(patterns),
453+
Arc::new(replacements),
454+
]);
455+
let pattern_err = re.expect_err("broken pattern should have failed");
456+
assert_eq!(
457+
pattern_err.to_string(),
458+
"Execution error: regex parse error:\n [\n ^\nerror: unclosed character class"
459+
);
460+
}
461+
462+
#[test]
463+
fn test_regexp_can_specialize_all_cases() {
464+
macro_rules! make_scalar {
465+
() => {
466+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("foo".to_string())))
467+
};
468+
}
469+
470+
macro_rules! make_array {
471+
() => {
472+
ColumnarValue::Array(
473+
Arc::new(StringArray::from(vec!["bar"; 2])) as ArrayRef
474+
)
475+
};
476+
}
477+
478+
for source in [make_scalar!(), make_array!()] {
479+
for pattern in [make_scalar!(), make_array!()] {
480+
for replacement in [make_scalar!(), make_array!()] {
481+
for flags in [Some(make_scalar!()), Some(make_array!()), None] {
482+
let mut args =
483+
vec![source.clone(), pattern.clone(), replacement.clone()];
484+
if let Some(flags) = flags {
485+
args.push(flags.clone());
486+
}
487+
let regex_func = specialize_regexp_replace::<i32>(&args);
488+
assert!(regex_func.is_ok());
489+
}
490+
}
491+
}
492+
}
493+
}
234494
}

0 commit comments

Comments
 (0)