2121
2222//! Regex expressions
2323
24- use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait } ;
24+ use arrow:: array:: {
25+ new_null_array, Array , ArrayRef , GenericStringArray , OffsetSizeTrait ,
26+ } ;
2527use arrow:: compute;
2628use datafusion_common:: { DataFusionError , Result } ;
29+ use datafusion_expr:: { ColumnarValue , ScalarFunctionImplementation } ;
2730use hashbrown:: HashMap ;
2831use lazy_static:: lazy_static;
2932use regex:: Regex ;
3033use std:: any:: type_name;
3134use std:: sync:: Arc ;
3235
33- macro_rules! downcast_string_arg {
36+ use crate :: functions:: make_scalar_function;
37+
/// Fetches the single string value of an argument that was derived from a
/// scalar (an array whose rows are all expected to hold the same value).
///
/// The argument is downcast with `downcast_string_array_arg!` and the value
/// of its first row is produced as a `&str`. When there is no usable first
/// row (the array is empty, or row 0 is null) the *enclosing function*
/// returns `$EARLY_ABORT(array)` instead of continuing.
macro_rules! fetch_string_arg {
    ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{
        let array = downcast_string_array_arg!($ARG, $NAME, $T);
        // An empty array has no row 0, so it must not be read; hand it to
        // the early-abort path just like a null scalar.
        if array.len() == 0 || array.is_null(0) {
            return $EARLY_ABORT(array);
        } else {
            array.value(0)
        }
    }};
}
48+
49+ macro_rules! downcast_string_array_arg {
3450 ( $ARG: expr, $NAME: expr, $T: ident) => { {
3551 $ARG. as_any( )
3652 . downcast_ref:: <GenericStringArray <T >>( )
@@ -48,14 +64,14 @@ macro_rules! downcast_string_arg {
4864pub fn regexp_match < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
4965 match args. len ( ) {
5066 2 => {
51- let values = downcast_string_arg ! ( args[ 0 ] , "string" , T ) ;
52- let regex = downcast_string_arg ! ( args[ 1 ] , "pattern" , T ) ;
67+ let values = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
68+ let regex = downcast_string_array_arg ! ( args[ 1 ] , "pattern" , T ) ;
5369 compute:: regexp_match ( values, regex, None ) . map_err ( DataFusionError :: ArrowError )
5470 }
5571 3 => {
56- let values = downcast_string_arg ! ( args[ 0 ] , "string" , T ) ;
57- let regex = downcast_string_arg ! ( args[ 1 ] , "pattern" , T ) ;
58- let flags = Some ( downcast_string_arg ! ( args[ 2 ] , "flags" , T ) ) ;
72+ let values = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
73+ let regex = downcast_string_array_arg ! ( args[ 1 ] , "pattern" , T ) ;
74+ let flags = Some ( downcast_string_array_arg ! ( args[ 2 ] , "flags" , T ) ) ;
5975 compute:: regexp_match ( values, regex, flags) . map_err ( DataFusionError :: ArrowError )
6076 }
6177 other => Err ( DataFusionError :: Internal ( format ! (
@@ -80,14 +96,17 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
8096///
8197/// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'`
8298pub fn regexp_replace < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
99+ // Default implementation for regexp_replace, assumes all args are arrays
100+ // and args is a sequence of 3 or 4 elements.
101+
83102 // creating Regex is expensive so create hashmap for memoization
84103 let mut patterns: HashMap < String , Regex > = HashMap :: new ( ) ;
85104
86105 match args. len ( ) {
87106 3 => {
88- let string_array = downcast_string_arg ! ( args[ 0 ] , "string" , T ) ;
89- let pattern_array = downcast_string_arg ! ( args[ 1 ] , "pattern" , T ) ;
90- let replacement_array = downcast_string_arg ! ( args[ 2 ] , "replacement" , T ) ;
107+ let string_array = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
108+ let pattern_array = downcast_string_array_arg ! ( args[ 1 ] , "pattern" , T ) ;
109+ let replacement_array = downcast_string_array_arg ! ( args[ 2 ] , "replacement" , T ) ;
91110
92111 let result = string_array
93112 . iter ( )
@@ -120,10 +139,10 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
120139 Ok ( Arc :: new ( result) as ArrayRef )
121140 }
122141 4 => {
123- let string_array = downcast_string_arg ! ( args[ 0 ] , "string" , T ) ;
124- let pattern_array = downcast_string_arg ! ( args[ 1 ] , "pattern" , T ) ;
125- let replacement_array = downcast_string_arg ! ( args[ 2 ] , "replacement" , T ) ;
126- let flags_array = downcast_string_arg ! ( args[ 3 ] , "flags" , T ) ;
142+ let string_array = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
143+ let pattern_array = downcast_string_array_arg ! ( args[ 1 ] , "pattern" , T ) ;
144+ let replacement_array = downcast_string_array_arg ! ( args[ 2 ] , "replacement" , T ) ;
145+ let flags_array = downcast_string_array_arg ! ( args[ 3 ] , "flags" , T ) ;
127146
128147 let result = string_array
129148 . iter ( )
@@ -178,10 +197,125 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
178197 }
179198}
180199
200+ fn _regexp_replace_early_abort < T : OffsetSizeTrait > (
201+ input_array : & GenericStringArray < T > ,
202+ ) -> Result < ArrayRef > {
203+ // Mimicking the existing behavior of regexp_replace, if any of the scalar arguments
204+ // are actuall null, then the result will be an array of the same size but with nulls.
205+ Ok ( new_null_array ( input_array. data_type ( ) , input_array. len ( ) ) )
206+ }
207+
208+ /// Special cased regex_replace implementation for the scenerio where
209+ /// the pattern, replacement and flags are static (arrays that are derived
210+ /// from scalars). This means we can skip regex caching system and basically
211+ /// hold a single Regex object for the replace operation. This also speeds
212+ /// up the pre-processing time of the replacement string, since it only
213+ /// needs to processed once.
214+ fn _regexp_replace_static_pattern_replace < T : OffsetSizeTrait > (
215+ args : & [ ArrayRef ] ,
216+ ) -> Result < ArrayRef > {
217+ let string_array = downcast_string_array_arg ! ( args[ 0 ] , "string" , T ) ;
218+ let pattern = fetch_string_arg ! ( args[ 1 ] , "pattern" , T , _regexp_replace_early_abort) ;
219+ let replacement =
220+ fetch_string_arg ! ( args[ 2 ] , "replacement" , T , _regexp_replace_early_abort) ;
221+ let flags = match args. len ( ) {
222+ 3 => None ,
223+ 4 => Some ( fetch_string_arg ! ( args[ 3 ] , "flags" , T , _regexp_replace_early_abort) ) ,
224+ other => {
225+ return Err ( DataFusionError :: Internal ( format ! (
226+ "regexp_replace was called with {} arguments. It requires at least 3 and at most 4." ,
227+ other
228+ ) ) )
229+ }
230+ } ;
231+
232+ // Embed the flag (if it exists) into the pattern. Limit will determine
233+ // whether this is a global match (as in replace all) or just a single
234+ // replace operation.
235+ let ( pattern, limit) = match flags {
236+ Some ( "g" ) => ( pattern. to_string ( ) , 0 ) ,
237+ Some ( flags) => (
238+ format ! ( "(?{}){}" , flags. to_string( ) . replace( 'g' , "" ) , pattern) ,
239+ !flags. contains ( 'g' ) as usize ,
240+ ) ,
241+ None => ( pattern. to_string ( ) , 1 ) ,
242+ } ;
243+
244+ let re = Regex :: new ( & pattern)
245+ . map_err ( |err| DataFusionError :: Execution ( err. to_string ( ) ) ) ?;
246+
247+ // Replaces the posix groups in the replacement string
248+ // with rust ones.
249+ let replacement = regex_replace_posix_groups ( replacement) ;
250+
251+ let result = string_array
252+ . iter ( )
253+ . map ( |string| {
254+ string. map ( |string| re. replacen ( string, limit, replacement. as_str ( ) ) )
255+ } )
256+ . collect :: < GenericStringArray < T > > ( ) ;
257+ Ok ( Arc :: new ( result) as ArrayRef )
258+ }
259+
260+ /// Determine which implementation of the regexp_replace to use based
261+ /// on the given set of arguments.
262+ pub fn specialize_regexp_replace < T : OffsetSizeTrait > (
263+ args : & [ ColumnarValue ] ,
264+ ) -> Result < ScalarFunctionImplementation > {
265+ // This will serve as a dispatch table where we can
266+ // leverage it in order to determine whether the scalarity
267+ // of the given set of arguments fits a better specialized
268+ // function.
269+ let ( is_source_scalar, is_pattern_scalar, is_replacement_scalar, is_flags_scalar) = (
270+ matches ! ( args[ 0 ] , ColumnarValue :: Scalar ( _) ) ,
271+ matches ! ( args[ 1 ] , ColumnarValue :: Scalar ( _) ) ,
272+ matches ! ( args[ 2 ] , ColumnarValue :: Scalar ( _) ) ,
273+ // The forth argument (flags) is optional; so in the event that
274+ // it is not available, we'll claim that it is scalar.
275+ matches ! ( args. get( 3 ) , Some ( ColumnarValue :: Scalar ( _) ) | None ) ,
276+ ) ;
277+
278+ match (
279+ is_source_scalar,
280+ is_pattern_scalar,
281+ is_replacement_scalar,
282+ is_flags_scalar,
283+ ) {
284+ // This represents a very hot path for the case where the there is
285+ // a single pattern that is being matched against and a single replacement.
286+ // This is extremely important to specialize on since it removes the overhead
287+ // of DF's in-house regex pattern cache (since there will be at most a single
288+ // pattern) and the pre-processing of the same replacement pattern at each
289+ // query.
290+ //
291+ // The flags needs to be a scalar as well since each pattern is actually
292+ // constructed with the flags embedded into the pattern itself. This means
293+ // even if the pattern itself is scalar, if the flags are an array then
294+ // we will create many regexes and it is best to use the implementation
295+ // that caches it. If there are no flags, we can simply ignore it here,
296+ // and let the specialized function handle it.
297+ ( _, true , true , true ) => {
298+ // We still don't know the scalarity of source, so we need the adapter
299+ // even if it will do some extra work for the pattern and the flags.
300+ //
301+ // TODO: maybe we need a way of telling the adapter on which arguments
302+ // it can skip filling (so that we won't create N - 1 redundant cols).
303+ Ok ( make_scalar_function (
304+ _regexp_replace_static_pattern_replace :: < T > ,
305+ ) )
306+ }
307+
308+ // If there are no specialized implementations, we'll fall back to the
309+ // generic implementation.
310+ ( _, _, _, _) => Ok ( make_scalar_function ( regexp_replace :: < T > ) ) ,
311+ }
312+ }
313+
181314#[ cfg( test) ]
182315mod tests {
183316 use super :: * ;
184317 use arrow:: array:: * ;
318+ use datafusion_common:: ScalarValue ;
185319
186320 #[ test]
187321 fn test_case_sensitive_regexp_match ( ) {
@@ -231,4 +365,130 @@ mod tests {
231365
232366 assert_eq ! ( re. as_ref( ) , & expected) ;
233367 }
368+
369+ #[ test]
370+ fn test_static_pattern_regexp_replace ( ) {
371+ let values = StringArray :: from ( vec ! [ "abc" ; 5 ] ) ;
372+ let patterns = StringArray :: from ( vec ! [ "b" ; 5 ] ) ;
373+ let replacements = StringArray :: from ( vec ! [ "foo" ; 5 ] ) ;
374+ let expected = StringArray :: from ( vec ! [ "afooc" ; 5 ] ) ;
375+
376+ let re = _regexp_replace_static_pattern_replace :: < i32 > ( & [
377+ Arc :: new ( values) ,
378+ Arc :: new ( patterns) ,
379+ Arc :: new ( replacements) ,
380+ ] )
381+ . unwrap ( ) ;
382+
383+ assert_eq ! ( re. as_ref( ) , & expected) ;
384+ }
385+
386+ #[ test]
387+ fn test_static_pattern_regexp_replace_with_flags ( ) {
388+ let values = StringArray :: from ( vec ! [ "abc" , "ABC" , "aBc" , "AbC" , "aBC" ] ) ;
389+ let patterns = StringArray :: from ( vec ! [ "b" ; 5 ] ) ;
390+ let replacements = StringArray :: from ( vec ! [ "foo" ; 5 ] ) ;
391+ let flags = StringArray :: from ( vec ! [ "i" ; 5 ] ) ;
392+ let expected =
393+ StringArray :: from ( vec ! [ "afooc" , "AfooC" , "afooc" , "AfooC" , "afooC" ] ) ;
394+
395+ let re = _regexp_replace_static_pattern_replace :: < i32 > ( & [
396+ Arc :: new ( values) ,
397+ Arc :: new ( patterns) ,
398+ Arc :: new ( replacements) ,
399+ Arc :: new ( flags) ,
400+ ] )
401+ . unwrap ( ) ;
402+
403+ assert_eq ! ( re. as_ref( ) , & expected) ;
404+ }
405+
406+ #[ test]
407+ fn test_static_pattern_regexp_replace_early_abort ( ) {
408+ let values = StringArray :: from ( vec ! [ "abc" ; 5 ] ) ;
409+ let patterns = StringArray :: from ( vec ! [ None ; 5 ] ) ;
410+ let replacements = StringArray :: from ( vec ! [ "foo" ; 5 ] ) ;
411+ let expected = StringArray :: from ( vec ! [ None ; 5 ] ) ;
412+
413+ let re = _regexp_replace_static_pattern_replace :: < i32 > ( & [
414+ Arc :: new ( values) ,
415+ Arc :: new ( patterns) ,
416+ Arc :: new ( replacements) ,
417+ ] )
418+ . unwrap ( ) ;
419+
420+ assert_eq ! ( re. as_ref( ) , & expected) ;
421+ }
422+
423+ #[ test]
424+ fn test_static_pattern_regexp_replace_early_abort_flags ( ) {
425+ let values = StringArray :: from ( vec ! [ "abc" ; 5 ] ) ;
426+ let patterns = StringArray :: from ( vec ! [ "a" ; 5 ] ) ;
427+ let replacements = StringArray :: from ( vec ! [ "foo" ; 5 ] ) ;
428+ let flags = StringArray :: from ( vec ! [ None ; 5 ] ) ;
429+ let expected = StringArray :: from ( vec ! [ None ; 5 ] ) ;
430+
431+ let re = _regexp_replace_static_pattern_replace :: < i32 > ( & [
432+ Arc :: new ( values) ,
433+ Arc :: new ( patterns) ,
434+ Arc :: new ( replacements) ,
435+ Arc :: new ( flags) ,
436+ ] )
437+ . unwrap ( ) ;
438+
439+ assert_eq ! ( re. as_ref( ) , & expected) ;
440+ }
441+
442+ #[ test]
443+ fn test_static_pattern_regexp_replace_pattern_error ( ) {
444+ let values = StringArray :: from ( vec ! [ "abc" ; 5 ] ) ;
445+ // Delibaretely using an invalid pattern to see how the single pattern
446+ // error is propagated on regexp_replace.
447+ let patterns = StringArray :: from ( vec ! [ "[" ; 5 ] ) ;
448+ let replacements = StringArray :: from ( vec ! [ "foo" ; 5 ] ) ;
449+
450+ let re = _regexp_replace_static_pattern_replace :: < i32 > ( & [
451+ Arc :: new ( values) ,
452+ Arc :: new ( patterns) ,
453+ Arc :: new ( replacements) ,
454+ ] ) ;
455+ let pattern_err = re. expect_err ( "broken pattern should have failed" ) ;
456+ assert_eq ! (
457+ pattern_err. to_string( ) ,
458+ "Execution error: regex parse error:\n [\n ^\n error: unclosed character class"
459+ ) ;
460+ }
461+
462+ #[ test]
463+ fn test_regexp_can_specialize_all_cases ( ) {
464+ macro_rules! make_scalar {
465+ ( ) => {
466+ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( "foo" . to_string( ) ) ) )
467+ } ;
468+ }
469+
470+ macro_rules! make_array {
471+ ( ) => {
472+ ColumnarValue :: Array (
473+ Arc :: new( StringArray :: from( vec![ "bar" ; 2 ] ) ) as ArrayRef
474+ )
475+ } ;
476+ }
477+
478+ for source in [ make_scalar ! ( ) , make_array ! ( ) ] {
479+ for pattern in [ make_scalar ! ( ) , make_array ! ( ) ] {
480+ for replacement in [ make_scalar ! ( ) , make_array ! ( ) ] {
481+ for flags in [ Some ( make_scalar ! ( ) ) , Some ( make_array ! ( ) ) , None ] {
482+ let mut args =
483+ vec ! [ source. clone( ) , pattern. clone( ) , replacement. clone( ) ] ;
484+ if let Some ( flags) = flags {
485+ args. push ( flags. clone ( ) ) ;
486+ }
487+ let regex_func = specialize_regexp_replace :: < i32 > ( & args) ;
488+ assert ! ( regex_func. is_ok( ) ) ;
489+ }
490+ }
491+ }
492+ }
493+ }
234494}
0 commit comments