Skip to content

Commit 234217e

Browse files
authored
feat:implement sql style 'substr_index' string function (#8272)
* feat:implement sql style 'substr_index' string function * code format * code format * code format * fix index bound issue * code format * code format * add args len check * add sql tests * code format * doc format
1 parent f29bcf3 commit 234217e

File tree

11 files changed

+215
-3
lines changed

11 files changed

+215
-3
lines changed

datafusion/expr/src/built_in_function.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ pub enum BuiltinScalarFunction {
302302
OverLay,
303303
/// levenshtein
304304
Levenshtein,
305+
/// substr_index
306+
SubstrIndex,
305307
}
306308

307309
/// Maps the sql function name to `BuiltinScalarFunction`
@@ -470,6 +472,7 @@ impl BuiltinScalarFunction {
470472
BuiltinScalarFunction::ArrowTypeof => Volatility::Immutable,
471473
BuiltinScalarFunction::OverLay => Volatility::Immutable,
472474
BuiltinScalarFunction::Levenshtein => Volatility::Immutable,
475+
BuiltinScalarFunction::SubstrIndex => Volatility::Immutable,
473476

474477
// Stable builtin functions
475478
BuiltinScalarFunction::Now => Volatility::Stable,
@@ -773,6 +776,9 @@ impl BuiltinScalarFunction {
773776
return plan_err!("The to_hex function can only accept integers.");
774777
}
775778
}),
779+
BuiltinScalarFunction::SubstrIndex => {
780+
utf8_to_str_type(&input_expr_types[0], "substr_index")
781+
}
776782
BuiltinScalarFunction::ToTimestamp => Ok(match &input_expr_types[0] {
777783
Int64 => Timestamp(Second, None),
778784
_ => Timestamp(Nanosecond, None),
@@ -1235,6 +1241,14 @@ impl BuiltinScalarFunction {
12351241
self.volatility(),
12361242
),
12371243

1244+
BuiltinScalarFunction::SubstrIndex => Signature::one_of(
1245+
vec![
1246+
Exact(vec![Utf8, Utf8, Int64]),
1247+
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
1248+
],
1249+
self.volatility(),
1250+
),
1251+
12381252
BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => {
12391253
Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility())
12401254
}
@@ -1486,6 +1500,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
14861500
BuiltinScalarFunction::Upper => &["upper"],
14871501
BuiltinScalarFunction::Uuid => &["uuid"],
14881502
BuiltinScalarFunction::Levenshtein => &["levenshtein"],
1503+
BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"],
14891504

14901505
// regex functions
14911506
BuiltinScalarFunction::RegexpMatch => &["regexp_match"],

datafusion/expr/src/expr_fn.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,7 @@ scalar_expr!(
916916

917917
scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type");
918918
scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings");
919+
scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter");
919920

920921
scalar_expr!(
921922
Struct,
@@ -1205,6 +1206,7 @@ mod test {
12051206
test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
12061207
test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
12071208
test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
1209+
test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count);
12081210
}
12091211

12101212
#[test]

datafusion/physical-expr/src/functions.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,29 @@ pub fn create_physical_fun(
862862
))),
863863
})
864864
}
865+
BuiltinScalarFunction::SubstrIndex => {
866+
Arc::new(|args| match args[0].data_type() {
867+
DataType::Utf8 => {
868+
let func = invoke_if_unicode_expressions_feature_flag!(
869+
substr_index,
870+
i32,
871+
"substr_index"
872+
);
873+
make_scalar_function(func)(args)
874+
}
875+
DataType::LargeUtf8 => {
876+
let func = invoke_if_unicode_expressions_feature_flag!(
877+
substr_index,
878+
i64,
879+
"substr_index"
880+
);
881+
make_scalar_function(func)(args)
882+
}
883+
other => Err(DataFusionError::Internal(format!(
884+
"Unsupported data type {other:?} for function substr_index",
885+
))),
886+
})
887+
}
865888
})
866889
}
867890

datafusion/physical-expr/src/unicode_expressions.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,3 +455,68 @@ pub fn translate<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
455455

456456
Ok(Arc::new(result) as ArrayRef)
457457
}
458+
459+
/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
460+
/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www
461+
/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
462+
/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
463+
/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
464+
pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
465+
if args.len() != 3 {
466+
return internal_err!(
467+
"substr_index was called with {} arguments. It requires 3.",
468+
args.len()
469+
);
470+
}
471+
472+
let string_array = as_generic_string_array::<T>(&args[0])?;
473+
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
474+
let count_array = as_int64_array(&args[2])?;
475+
476+
let result = string_array
477+
.iter()
478+
.zip(delimiter_array.iter())
479+
.zip(count_array.iter())
480+
.map(|((string, delimiter), n)| match (string, delimiter, n) {
481+
(Some(string), Some(delimiter), Some(n)) => {
482+
let mut res = String::new();
483+
match n {
484+
0 => {
485+
"".to_string();
486+
}
487+
_other => {
488+
if n > 0 {
489+
let idx = string
490+
.split(delimiter)
491+
.take(n as usize)
492+
.fold(0, |len, x| len + x.len() + delimiter.len())
493+
- delimiter.len();
494+
res.push_str(if idx >= string.len() {
495+
string
496+
} else {
497+
&string[..idx]
498+
});
499+
} else {
500+
let idx = (string.split(delimiter).take((-n) as usize).fold(
501+
string.len() as isize,
502+
|len, x| {
503+
len - x.len() as isize - delimiter.len() as isize
504+
},
505+
) + delimiter.len() as isize)
506+
as usize;
507+
res.push_str(if idx >= string.len() {
508+
string
509+
} else {
510+
&string[idx..]
511+
});
512+
}
513+
}
514+
}
515+
Some(res)
516+
}
517+
_ => None,
518+
})
519+
.collect::<GenericStringArray<T>>();
520+
521+
Ok(Arc::new(result) as ArrayRef)
522+
}

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,7 @@ enum ScalarFunction {
641641
ArrayExcept = 123;
642642
ArrayPopFront = 124;
643643
Levenshtein = 125;
644+
SubstrIndex = 126;
644645
}
645646

646647
message ScalarFunctionNode {

datafusion/proto/src/generated/pbjson.rs

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/generated/prost.rs

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ use datafusion_expr::{
5555
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power,
5656
radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right,
5757
round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part,
58-
sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substring, tan, tanh,
59-
to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_nanos,
60-
to_timestamp_seconds, translate, trim, trunc, upper, uuid,
58+
sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index,
59+
substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis,
60+
to_timestamp_nanos, to_timestamp_seconds, translate, trim, trunc, upper, uuid,
6161
window_frame::regularize,
6262
AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction,
6363
Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet,
@@ -551,6 +551,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
551551
ScalarFunction::ArrowTypeof => Self::ArrowTypeof,
552552
ScalarFunction::OverLay => Self::OverLay,
553553
ScalarFunction::Levenshtein => Self::Levenshtein,
554+
ScalarFunction::SubstrIndex => Self::SubstrIndex,
554555
}
555556
}
556557
}
@@ -1716,6 +1717,11 @@ pub fn parse_expr(
17161717
.map(|expr| parse_expr(expr, registry))
17171718
.collect::<Result<Vec<_>, _>>()?,
17181719
)),
1720+
ScalarFunction::SubstrIndex => Ok(substr_index(
1721+
parse_expr(&args[0], registry)?,
1722+
parse_expr(&args[1], registry)?,
1723+
parse_expr(&args[2], registry)?,
1724+
)),
17191725
ScalarFunction::StructFun => {
17201726
Ok(struct_fun(parse_expr(&args[0], registry)?))
17211727
}

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,6 +1583,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
15831583
BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof,
15841584
BuiltinScalarFunction::OverLay => Self::OverLay,
15851585
BuiltinScalarFunction::Levenshtein => Self::Levenshtein,
1586+
BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex,
15861587
};
15871588

15881589
Ok(scalar_function)

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,3 +877,78 @@ query ?
877877
SELECT levenshtein(NULL, NULL)
878878
----
879879
NULL
880+
881+
query T
882+
SELECT substr_index('www.apache.org', '.', 1)
883+
----
884+
www
885+
886+
query T
887+
SELECT substr_index('www.apache.org', '.', 2)
888+
----
889+
www.apache
890+
891+
query T
892+
SELECT substr_index('www.apache.org', '.', -1)
893+
----
894+
org
895+
896+
query T
897+
SELECT substr_index('www.apache.org', '.', -2)
898+
----
899+
apache.org
900+
901+
query T
902+
SELECT substr_index('www.apache.org', 'ac', 1)
903+
----
904+
www.ap
905+
906+
query T
907+
SELECT substr_index('www.apache.org', 'ac', -1)
908+
----
909+
he.org
910+
911+
query T
912+
SELECT substr_index('www.apache.org', 'ac', 2)
913+
----
914+
www.apache.org
915+
916+
query T
917+
SELECT substr_index('www.apache.org', 'ac', -2)
918+
----
919+
www.apache.org
920+
921+
query ?
922+
SELECT substr_index(NULL, 'ac', 1)
923+
----
924+
NULL
925+
926+
query T
927+
SELECT substr_index('www.apache.org', NULL, 1)
928+
----
929+
NULL
930+
931+
query T
932+
SELECT substr_index('www.apache.org', 'ac', NULL)
933+
----
934+
NULL
935+
936+
query T
937+
SELECT substr_index('', 'ac', 1)
938+
----
939+
(empty)
940+
941+
query T
942+
SELECT substr_index('www.apache.org', '', 1)
943+
----
944+
(empty)
945+
946+
query T
947+
SELECT substr_index('www.apache.org', 'ac', 0)
948+
----
949+
(empty)
950+
951+
query ?
952+
SELECT substr_index(NULL, NULL, NULL)
953+
----
954+
NULL

0 commit comments

Comments
 (0)