@@ -22,18 +22,19 @@ use arrow::array::*;
2222use arrow:: compute:: kernels:: arithmetic:: {
2323 add, divide, divide_scalar, modulus, modulus_scalar, multiply, subtract,
2424} ;
25- use arrow:: compute:: kernels:: boolean:: { and_kleene, or_kleene} ;
25+ use arrow:: compute:: kernels:: boolean:: { and_kleene, not , or_kleene} ;
2626use arrow:: compute:: kernels:: comparison:: { eq, gt, gt_eq, lt, lt_eq, neq} ;
2727use arrow:: compute:: kernels:: comparison:: {
2828 eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
2929} ;
3030use arrow:: compute:: kernels:: comparison:: {
31- eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, like_utf8_scalar , lt_eq_utf8, lt_utf8,
32- neq_utf8 , nlike_utf8 , nlike_utf8_scalar ,
31+ eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8 , nlike_utf8 ,
32+ regexp_is_match_utf8 ,
3333} ;
3434use arrow:: compute:: kernels:: comparison:: {
35- eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, lt_eq_utf8_scalar, lt_utf8_scalar,
36- neq_utf8_scalar,
35+ eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, like_utf8_scalar,
36+ lt_eq_utf8_scalar, lt_utf8_scalar, neq_utf8_scalar, nlike_utf8_scalar,
37+ regexp_is_match_utf8_scalar,
3738} ;
3839use arrow:: datatypes:: { DataType , Schema , TimeUnit } ;
3940use arrow:: record_batch:: RecordBatch ;
@@ -44,7 +45,9 @@ use crate::physical_plan::expressions::try_cast;
4445use crate :: physical_plan:: { ColumnarValue , PhysicalExpr } ;
4546use crate :: scalar:: ScalarValue ;
4647
47- use super :: coercion:: { eq_coercion, like_coercion, numerical_coercion, order_coercion} ;
48+ use super :: coercion:: {
49+ eq_coercion, like_coercion, numerical_coercion, order_coercion, string_coercion,
50+ } ;
4851
4952/// Binary expression
5053#[ derive( Debug ) ]
@@ -339,6 +342,91 @@ macro_rules! boolean_op {
339342 } } ;
340343}
341344
345+ macro_rules! binary_string_array_flag_op {
346+ ( $LEFT: expr, $RIGHT: expr, $OP: ident, $NOT: expr, $FLAG: expr) => { {
347+ match $LEFT. data_type( ) {
348+ DataType :: Utf8 => {
349+ compute_utf8_flag_op!( $LEFT, $RIGHT, $OP, StringArray , $NOT, $FLAG)
350+ }
351+ DataType :: LargeUtf8 => {
352+ compute_utf8_flag_op!( $LEFT, $RIGHT, $OP, LargeStringArray , $NOT, $FLAG)
353+ }
354+ other => Err ( DataFusionError :: Internal ( format!(
355+ "Data type {:?} not supported for binary_string_array_flag_op operation on string array" ,
356+ other
357+ ) ) ) ,
358+ }
359+ } } ;
360+ }
361+
362+ /// Invoke a compute kernel on a pair of binary data arrays with flags
363+ macro_rules! compute_utf8_flag_op {
364+ ( $LEFT: expr, $RIGHT: expr, $OP: ident, $ARRAYTYPE: ident, $NOT: expr, $FLAG: expr) => { {
365+ let ll = $LEFT
366+ . as_any( )
367+ . downcast_ref:: <$ARRAYTYPE>( )
368+ . expect( "compute_utf8_flag_op failed to downcast array" ) ;
369+ let rr = $RIGHT
370+ . as_any( )
371+ . downcast_ref:: <$ARRAYTYPE>( )
372+ . expect( "compute_utf8_flag_op failed to downcast array" ) ;
373+
374+ let flag = if $FLAG {
375+ Some ( $ARRAYTYPE:: from( vec![ "i" ; ll. len( ) ] ) )
376+ } else {
377+ None
378+ } ;
379+ let mut array = paste:: expr! { [ <$OP _utf8>] } ( & ll, & rr, flag. as_ref( ) ) ?;
380+ if $NOT {
381+ array = not( & array) . unwrap( ) ;
382+ }
383+ Ok ( Arc :: new( array) )
384+ } } ;
385+ }
386+
387+ macro_rules! binary_string_array_flag_op_scalar {
388+ ( $LEFT: expr, $RIGHT: expr, $OP: ident, $NOT: expr, $FLAG: expr) => { {
389+ let result: Result <Arc <dyn Array >> = match $LEFT. data_type( ) {
390+ DataType :: Utf8 => {
391+ compute_utf8_flag_op_scalar!( $LEFT, $RIGHT, $OP, StringArray , $NOT, $FLAG)
392+ }
393+ DataType :: LargeUtf8 => {
394+ compute_utf8_flag_op_scalar!( $LEFT, $RIGHT, $OP, LargeStringArray , $NOT, $FLAG)
395+ }
396+ other => Err ( DataFusionError :: Internal ( format!(
397+ "Data type {:?} not supported for binary_string_array_flag_op_scalar operation on string array" ,
398+ other
399+ ) ) ) ,
400+ } ;
401+ Some ( result)
402+ } } ;
403+ }
404+
405+ /// Invoke a compute kernel on a data array and a scalar value with flag
406+ macro_rules! compute_utf8_flag_op_scalar {
407+ ( $LEFT: expr, $RIGHT: expr, $OP: ident, $ARRAYTYPE: ident, $NOT: expr, $FLAG: expr) => { {
408+ let ll = $LEFT
409+ . as_any( )
410+ . downcast_ref:: <$ARRAYTYPE>( )
411+ . expect( "compute_utf8_flag_op_scalar failed to downcast array" ) ;
412+
413+ if let ScalarValue :: Utf8 ( Some ( string_value) ) = $RIGHT {
414+ let flag = if $FLAG { Some ( "i" ) } else { None } ;
415+ let mut array =
416+ paste:: expr! { [ <$OP _utf8_scalar>] } ( & ll, & string_value, flag) ?;
417+ if $NOT {
418+ array = not( & array) . unwrap( ) ;
419+ }
420+ Ok ( Arc :: new( array) )
421+ } else {
422+ Err ( DataFusionError :: Internal ( format!(
423+ "compute_utf8_flag_op_scalar failed to cast literal value {}" ,
424+ $RIGHT
425+ ) ) )
426+ }
427+ } } ;
428+ }
429+
342430/// Coercion rules for all binary operators. Returns the output type
343431/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
344432fn common_binary_type (
@@ -368,6 +456,10 @@ fn common_binary_type(
368456 | Operator :: Modulo
369457 | Operator :: Divide
370458 | Operator :: Multiply => numerical_coercion ( lhs_type, rhs_type) ,
459+ Operator :: RegexMatch
460+ | Operator :: RegexIMatch
461+ | Operator :: RegexNotMatch
462+ | Operator :: RegexNotIMatch => string_coercion ( lhs_type, rhs_type) ,
371463 } ;
372464
373465 // re-write the error message of failed coercions to include the operator's information
@@ -406,7 +498,11 @@ pub fn binary_operator_data_type(
406498 | Operator :: Lt
407499 | Operator :: Gt
408500 | Operator :: GtEq
409- | Operator :: LtEq => Ok ( DataType :: Boolean ) ,
501+ | Operator :: LtEq
502+ | Operator :: RegexMatch
503+ | Operator :: RegexIMatch
504+ | Operator :: RegexNotMatch
505+ | Operator :: RegexNotIMatch => Ok ( DataType :: Boolean ) ,
410506 // math operations return the same value as the common coerced type
411507 Operator :: Plus
412508 | Operator :: Minus
@@ -475,6 +571,34 @@ impl PhysicalExpr for BinaryExpr {
475571 Operator :: Modulo => {
476572 binary_primitive_array_op_scalar ! ( array, scalar. clone( ) , modulus)
477573 }
574+ Operator :: RegexMatch => binary_string_array_flag_op_scalar ! (
575+ array,
576+ scalar. clone( ) ,
577+ regexp_is_match,
578+ false ,
579+ false
580+ ) ,
581+ Operator :: RegexIMatch => binary_string_array_flag_op_scalar ! (
582+ array,
583+ scalar. clone( ) ,
584+ regexp_is_match,
585+ false ,
586+ true
587+ ) ,
588+ Operator :: RegexNotMatch => binary_string_array_flag_op_scalar ! (
589+ array,
590+ scalar. clone( ) ,
591+ regexp_is_match,
592+ true ,
593+ false
594+ ) ,
595+ Operator :: RegexNotIMatch => binary_string_array_flag_op_scalar ! (
596+ array,
597+ scalar. clone( ) ,
598+ regexp_is_match,
599+ true ,
600+ true
601+ ) ,
478602 // if scalar operation is not supported - fallback to array implementation
479603 _ => None ,
480604 }
@@ -547,6 +671,18 @@ impl PhysicalExpr for BinaryExpr {
547671 ) ) ) ;
548672 }
549673 }
674+ Operator :: RegexMatch => {
675+ binary_string_array_flag_op ! ( left, right, regexp_is_match, false , false )
676+ }
677+ Operator :: RegexIMatch => {
678+ binary_string_array_flag_op ! ( left, right, regexp_is_match, false , true )
679+ }
680+ Operator :: RegexNotMatch => {
681+ binary_string_array_flag_op ! ( left, right, regexp_is_match, true , false )
682+ }
683+ Operator :: RegexNotIMatch => {
684+ binary_string_array_flag_op ! ( left, right, regexp_is_match, true , true )
685+ }
550686 } ;
551687 result. map ( |a| ColumnarValue :: Array ( a) )
552688 }
@@ -822,6 +958,102 @@ mod tests {
822958 DataType :: Boolean ,
823959 vec![ true , false ]
824960 ) ;
961+ test_coercion ! (
962+ StringArray ,
963+ DataType :: Utf8 ,
964+ vec![ "abc" ; 5 ] ,
965+ StringArray ,
966+ DataType :: Utf8 ,
967+ vec![ "^a" , "^A" , "(b|d)" , "(B|D)" , "^(b|c)" ] ,
968+ Operator :: RegexMatch ,
969+ BooleanArray ,
970+ DataType :: Boolean ,
971+ vec![ true , false , true , false , false ]
972+ ) ;
973+ test_coercion ! (
974+ StringArray ,
975+ DataType :: Utf8 ,
976+ vec![ "abc" ; 5 ] ,
977+ StringArray ,
978+ DataType :: Utf8 ,
979+ vec![ "^a" , "^A" , "(b|d)" , "(B|D)" , "^(b|c)" ] ,
980+ Operator :: RegexIMatch ,
981+ BooleanArray ,
982+ DataType :: Boolean ,
983+ vec![ true , true , true , true , false ]
984+ ) ;
985+ test_coercion ! (
986+ StringArray ,
987+ DataType :: Utf8 ,
988+ vec![ "abc" ; 5 ] ,
989+ StringArray ,
990+ DataType :: Utf8 ,
991+ vec![ "^a" , "^A" , "(b|d)" , "(B|D)" , "^(b|c)" ] ,
992+ Operator :: RegexNotMatch ,
993+ BooleanArray ,
994+ DataType :: Boolean ,
995+ vec![ false , true , false , true , true ]
996+ ) ;
997+ test_coercion ! (
998+ StringArray ,
999+ DataType :: Utf8 ,
1000+ vec![ "abc" ; 5 ] ,
1001+ StringArray ,
1002+ DataType :: Utf8 ,
1003+ vec![ "^a" , "^A" , "(b|d)" , "(B|D)" , "^(b|c)" ] ,
1004+ Operator :: RegexNotIMatch ,
1005+ BooleanArray ,
1006+ DataType :: Boolean ,
1007+ vec![ false , false , false , false , true ]
1008+ ) ;
1009+ test_coercion ! (
1010+ LargeStringArray ,
1011+ DataType :: LargeUtf8 ,
1012+ vec![ "abc" ; 5 ] ,
1013+ LargeStringArray ,
1014+ DataType :: LargeUtf8 ,
1015+ vec![ "^a" , "^A" , "(b|d)" , "(B|D)" , "^(b|c)" ] ,
1016+ Operator :: RegexMatch ,
1017+ BooleanArray ,
1018+ DataType :: Boolean ,
1019+ vec![ true , false , true , false , false ]
1020+ ) ;
1021+ test_coercion ! (
1022+ LargeStringArray ,
1023+ DataType :: LargeUtf8 ,
1024+ vec![ "abc" ; 5 ] ,
1025+ LargeStringArray ,
1026+ DataType :: LargeUtf8 ,
1027+ vec![ "^a" , "^A" , "(b|d)" , "(B|D)" , "^(b|c)" ] ,
1028+ Operator :: RegexIMatch ,
1029+ BooleanArray ,
1030+ DataType :: Boolean ,
1031+ vec![ true , true , true , true , false ]
1032+ ) ;
1033+ test_coercion ! (
1034+ LargeStringArray ,
1035+ DataType :: LargeUtf8 ,
1036+ vec![ "abc" ; 5 ] ,
1037+ LargeStringArray ,
1038+ DataType :: LargeUtf8 ,
1039+ vec![ "^a" , "^A" , "(b|d)" , "(B|D)" , "^(b|c)" ] ,
1040+ Operator :: RegexNotMatch ,
1041+ BooleanArray ,
1042+ DataType :: Boolean ,
1043+ vec![ false , true , false , true , true ]
1044+ ) ;
1045+ test_coercion ! (
1046+ LargeStringArray ,
1047+ DataType :: LargeUtf8 ,
1048+ vec![ "abc" ; 5 ] ,
1049+ LargeStringArray ,
1050+ DataType :: LargeUtf8 ,
1051+ vec![ "^a" , "^A" , "(b|d)" , "(B|D)" , "^(b|c)" ] ,
1052+ Operator :: RegexNotIMatch ,
1053+ BooleanArray ,
1054+ DataType :: Boolean ,
1055+ vec![ false , false , false , false , true ]
1056+ ) ;
8251057 Ok ( ( ) )
8261058 }
8271059
0 commit comments