Skip to content

Commit dd5f936

Browse files
authored
Add support for PostgreSQL regex match (#870)
1 parent eef5e2d commit dd5f936

File tree

10 files changed

+325
-15
lines changed

10 files changed

+325
-15
lines changed

ballista-examples/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ edition = "2018"
2828
publish = false
2929

3030
[dependencies]
31-
arrow-flight = { version = "^5.2" }
31+
arrow-flight = { version = "^5.3" }
3232
datafusion = { path = "../datafusion" }
3333
ballista = { path = "../ballista/rust/client" }
3434
prost = "0.8"

ballista/rust/core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ tokio = "1.0"
4242
tonic = "0.5"
4343
uuid = { version = "0.8", features = ["v4"] }
4444

45-
arrow-flight = { version = "^5.2" }
45+
arrow-flight = { version = "^5.3" }
4646

4747
datafusion = { path = "../../../datafusion", version = "5.1.0" }
4848

ballista/rust/executor/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ edition = "2018"
2929
snmalloc = ["snmalloc-rs"]
3030

3131
[dependencies]
32-
arrow = { version = "^5.2" }
33-
arrow-flight = { version = "^5.2" }
32+
arrow = { version = "^5.3" }
33+
arrow-flight = { version = "^5.3" }
3434
anyhow = "1"
3535
async-trait = "0.1.36"
3636
ballista-core = { path = "../core", version = "0.6.0" }

datafusion-cli/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ clap = "2.33"
3131
rustyline = "8.0"
3232
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
3333
datafusion = { path = "../datafusion", version = "5.1.0" }
34-
arrow = { version = "^5.2" }
34+
arrow = { version = "^5.3" }
3535
ballista = { path = "../ballista/rust/client", version = "0.6.0" }

datafusion-examples/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ publish = false
2929

3030

3131
[dev-dependencies]
32-
arrow-flight = { version = "^5.2" }
32+
arrow-flight = { version = "^5.3" }
3333
datafusion = { path = "../datafusion" }
3434
prost = "0.8"
3535
tonic = "0.5"

datafusion/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ force_hash_collisions = []
4949
[dependencies]
5050
ahash = "0.7"
5151
hashbrown = { version = "0.11", features = ["raw"] }
52-
arrow = { version = "^5.2", features = ["prettyprint"] }
53-
parquet = { version = "^5.2", features = ["arrow"] }
52+
arrow = { version = "^5.3", features = ["prettyprint"] }
53+
parquet = { version = "^5.3", features = ["arrow"] }
5454
sqlparser = "0.10"
5555
paste = "^1.0"
5656
num_cpus = "1.13.0"

datafusion/src/logical_plan/operators.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ pub enum Operator {
5252
Like,
5353
/// Does not match a wildcard pattern
5454
NotLike,
55+
/// Case sensitive regex match
56+
RegexMatch,
57+
/// Case insensitive regex match
58+
RegexIMatch,
59+
/// Case sensitive regex not match
60+
RegexNotMatch,
61+
/// Case insensitive regex not match
62+
RegexNotIMatch,
5563
}
5664

5765
impl fmt::Display for Operator {
@@ -72,6 +80,10 @@ impl fmt::Display for Operator {
7280
Operator::Or => "OR",
7381
Operator::Like => "LIKE",
7482
Operator::NotLike => "NOT LIKE",
83+
Operator::RegexMatch => "~",
84+
Operator::RegexIMatch => "~*",
85+
Operator::RegexNotMatch => "!~",
86+
Operator::RegexNotIMatch => "!~*",
7587
};
7688
write!(f, "{}", display)
7789
}

datafusion/src/physical_plan/expressions/binary.rs

Lines changed: 239 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,19 @@ use arrow::array::*;
2222
use arrow::compute::kernels::arithmetic::{
2323
add, divide, divide_scalar, modulus, modulus_scalar, multiply, subtract,
2424
};
25-
use arrow::compute::kernels::boolean::{and_kleene, or_kleene};
25+
use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
2626
use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
2727
use arrow::compute::kernels::comparison::{
2828
eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
2929
};
3030
use arrow::compute::kernels::comparison::{
31-
eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, like_utf8_scalar, lt_eq_utf8, lt_utf8,
32-
neq_utf8, nlike_utf8, nlike_utf8_scalar,
31+
eq_utf8, gt_eq_utf8, gt_utf8, like_utf8, lt_eq_utf8, lt_utf8, neq_utf8, nlike_utf8,
32+
regexp_is_match_utf8,
3333
};
3434
use arrow::compute::kernels::comparison::{
35-
eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, lt_eq_utf8_scalar, lt_utf8_scalar,
36-
neq_utf8_scalar,
35+
eq_utf8_scalar, gt_eq_utf8_scalar, gt_utf8_scalar, like_utf8_scalar,
36+
lt_eq_utf8_scalar, lt_utf8_scalar, neq_utf8_scalar, nlike_utf8_scalar,
37+
regexp_is_match_utf8_scalar,
3738
};
3839
use arrow::datatypes::{DataType, Schema, TimeUnit};
3940
use arrow::record_batch::RecordBatch;
@@ -44,7 +45,9 @@ use crate::physical_plan::expressions::try_cast;
4445
use crate::physical_plan::{ColumnarValue, PhysicalExpr};
4546
use crate::scalar::ScalarValue;
4647

47-
use super::coercion::{eq_coercion, like_coercion, numerical_coercion, order_coercion};
48+
use super::coercion::{
49+
eq_coercion, like_coercion, numerical_coercion, order_coercion, string_coercion,
50+
};
4851

4952
/// Binary expression
5053
#[derive(Debug)]
@@ -339,6 +342,91 @@ macro_rules! boolean_op {
339342
}};
340343
}
341344

345+
macro_rules! binary_string_array_flag_op {
346+
($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
347+
match $LEFT.data_type() {
348+
DataType::Utf8 => {
349+
compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
350+
}
351+
DataType::LargeUtf8 => {
352+
compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
353+
}
354+
other => Err(DataFusionError::Internal(format!(
355+
"Data type {:?} not supported for binary_string_array_flag_op operation on string array",
356+
other
357+
))),
358+
}
359+
}};
360+
}
361+
362+
/// Invoke a compute kernel on a pair of binary data arrays with flags
363+
macro_rules! compute_utf8_flag_op {
364+
($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
365+
let ll = $LEFT
366+
.as_any()
367+
.downcast_ref::<$ARRAYTYPE>()
368+
.expect("compute_utf8_flag_op failed to downcast array");
369+
let rr = $RIGHT
370+
.as_any()
371+
.downcast_ref::<$ARRAYTYPE>()
372+
.expect("compute_utf8_flag_op failed to downcast array");
373+
374+
let flag = if $FLAG {
375+
Some($ARRAYTYPE::from(vec!["i"; ll.len()]))
376+
} else {
377+
None
378+
};
379+
let mut array = paste::expr! {[<$OP _utf8>]}(&ll, &rr, flag.as_ref())?;
380+
if $NOT {
381+
array = not(&array).unwrap();
382+
}
383+
Ok(Arc::new(array))
384+
}};
385+
}
386+
387+
macro_rules! binary_string_array_flag_op_scalar {
388+
($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
389+
let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
390+
DataType::Utf8 => {
391+
compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG)
392+
}
393+
DataType::LargeUtf8 => {
394+
compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG)
395+
}
396+
other => Err(DataFusionError::Internal(format!(
397+
"Data type {:?} not supported for binary_string_array_flag_op_scalar operation on string array",
398+
other
399+
))),
400+
};
401+
Some(result)
402+
}};
403+
}
404+
405+
/// Invoke a compute kernel on a data array and a scalar value with flag
406+
macro_rules! compute_utf8_flag_op_scalar {
407+
($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{
408+
let ll = $LEFT
409+
.as_any()
410+
.downcast_ref::<$ARRAYTYPE>()
411+
.expect("compute_utf8_flag_op_scalar failed to downcast array");
412+
413+
if let ScalarValue::Utf8(Some(string_value)) = $RIGHT {
414+
let flag = if $FLAG { Some("i") } else { None };
415+
let mut array =
416+
paste::expr! {[<$OP _utf8_scalar>]}(&ll, &string_value, flag)?;
417+
if $NOT {
418+
array = not(&array).unwrap();
419+
}
420+
Ok(Arc::new(array))
421+
} else {
422+
Err(DataFusionError::Internal(format!(
423+
"compute_utf8_flag_op_scalar failed to cast literal value {}",
424+
$RIGHT
425+
)))
426+
}
427+
}};
428+
}
429+
342430
/// Coercion rules for all binary operators. Returns the output type
343431
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
344432
fn common_binary_type(
@@ -368,6 +456,10 @@ fn common_binary_type(
368456
| Operator::Modulo
369457
| Operator::Divide
370458
| Operator::Multiply => numerical_coercion(lhs_type, rhs_type),
459+
Operator::RegexMatch
460+
| Operator::RegexIMatch
461+
| Operator::RegexNotMatch
462+
| Operator::RegexNotIMatch => string_coercion(lhs_type, rhs_type),
371463
};
372464

373465
// re-write the error message of failed coercions to include the operator's information
@@ -406,7 +498,11 @@ pub fn binary_operator_data_type(
406498
| Operator::Lt
407499
| Operator::Gt
408500
| Operator::GtEq
409-
| Operator::LtEq => Ok(DataType::Boolean),
501+
| Operator::LtEq
502+
| Operator::RegexMatch
503+
| Operator::RegexIMatch
504+
| Operator::RegexNotMatch
505+
| Operator::RegexNotIMatch => Ok(DataType::Boolean),
410506
// math operations return the same value as the common coerced type
411507
Operator::Plus
412508
| Operator::Minus
@@ -475,6 +571,34 @@ impl PhysicalExpr for BinaryExpr {
475571
Operator::Modulo => {
476572
binary_primitive_array_op_scalar!(array, scalar.clone(), modulus)
477573
}
574+
Operator::RegexMatch => binary_string_array_flag_op_scalar!(
575+
array,
576+
scalar.clone(),
577+
regexp_is_match,
578+
false,
579+
false
580+
),
581+
Operator::RegexIMatch => binary_string_array_flag_op_scalar!(
582+
array,
583+
scalar.clone(),
584+
regexp_is_match,
585+
false,
586+
true
587+
),
588+
Operator::RegexNotMatch => binary_string_array_flag_op_scalar!(
589+
array,
590+
scalar.clone(),
591+
regexp_is_match,
592+
true,
593+
false
594+
),
595+
Operator::RegexNotIMatch => binary_string_array_flag_op_scalar!(
596+
array,
597+
scalar.clone(),
598+
regexp_is_match,
599+
true,
600+
true
601+
),
478602
// if scalar operation is not supported - fallback to array implementation
479603
_ => None,
480604
}
@@ -547,6 +671,18 @@ impl PhysicalExpr for BinaryExpr {
547671
)));
548672
}
549673
}
674+
Operator::RegexMatch => {
675+
binary_string_array_flag_op!(left, right, regexp_is_match, false, false)
676+
}
677+
Operator::RegexIMatch => {
678+
binary_string_array_flag_op!(left, right, regexp_is_match, false, true)
679+
}
680+
Operator::RegexNotMatch => {
681+
binary_string_array_flag_op!(left, right, regexp_is_match, true, false)
682+
}
683+
Operator::RegexNotIMatch => {
684+
binary_string_array_flag_op!(left, right, regexp_is_match, true, true)
685+
}
550686
};
551687
result.map(|a| ColumnarValue::Array(a))
552688
}
@@ -822,6 +958,102 @@ mod tests {
822958
DataType::Boolean,
823959
vec![true, false]
824960
);
961+
test_coercion!(
962+
StringArray,
963+
DataType::Utf8,
964+
vec!["abc"; 5],
965+
StringArray,
966+
DataType::Utf8,
967+
vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
968+
Operator::RegexMatch,
969+
BooleanArray,
970+
DataType::Boolean,
971+
vec![true, false, true, false, false]
972+
);
973+
test_coercion!(
974+
StringArray,
975+
DataType::Utf8,
976+
vec!["abc"; 5],
977+
StringArray,
978+
DataType::Utf8,
979+
vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
980+
Operator::RegexIMatch,
981+
BooleanArray,
982+
DataType::Boolean,
983+
vec![true, true, true, true, false]
984+
);
985+
test_coercion!(
986+
StringArray,
987+
DataType::Utf8,
988+
vec!["abc"; 5],
989+
StringArray,
990+
DataType::Utf8,
991+
vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
992+
Operator::RegexNotMatch,
993+
BooleanArray,
994+
DataType::Boolean,
995+
vec![false, true, false, true, true]
996+
);
997+
test_coercion!(
998+
StringArray,
999+
DataType::Utf8,
1000+
vec!["abc"; 5],
1001+
StringArray,
1002+
DataType::Utf8,
1003+
vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1004+
Operator::RegexNotIMatch,
1005+
BooleanArray,
1006+
DataType::Boolean,
1007+
vec![false, false, false, false, true]
1008+
);
1009+
test_coercion!(
1010+
LargeStringArray,
1011+
DataType::LargeUtf8,
1012+
vec!["abc"; 5],
1013+
LargeStringArray,
1014+
DataType::LargeUtf8,
1015+
vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1016+
Operator::RegexMatch,
1017+
BooleanArray,
1018+
DataType::Boolean,
1019+
vec![true, false, true, false, false]
1020+
);
1021+
test_coercion!(
1022+
LargeStringArray,
1023+
DataType::LargeUtf8,
1024+
vec!["abc"; 5],
1025+
LargeStringArray,
1026+
DataType::LargeUtf8,
1027+
vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1028+
Operator::RegexIMatch,
1029+
BooleanArray,
1030+
DataType::Boolean,
1031+
vec![true, true, true, true, false]
1032+
);
1033+
test_coercion!(
1034+
LargeStringArray,
1035+
DataType::LargeUtf8,
1036+
vec!["abc"; 5],
1037+
LargeStringArray,
1038+
DataType::LargeUtf8,
1039+
vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1040+
Operator::RegexNotMatch,
1041+
BooleanArray,
1042+
DataType::Boolean,
1043+
vec![false, true, false, true, true]
1044+
);
1045+
test_coercion!(
1046+
LargeStringArray,
1047+
DataType::LargeUtf8,
1048+
vec!["abc"; 5],
1049+
LargeStringArray,
1050+
DataType::LargeUtf8,
1051+
vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"],
1052+
Operator::RegexNotIMatch,
1053+
BooleanArray,
1054+
DataType::Boolean,
1055+
vec![false, false, false, false, true]
1056+
);
8251057
Ok(())
8261058
}
8271059

0 commit comments

Comments
 (0)