Skip to content

Commit 4d89cd7

Browse files
feat: upgrade functions.rs
Upstream is continuing its migration to UDFs. Ref apache/datafusion#10098 Ref apache/datafusion#10372
1 parent 2be45eb commit 4d89cd7

File tree

1 file changed

+77
-71
lines changed

1 file changed

+77
-71
lines changed

src/functions.rs

Lines changed: 77 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,46 @@ use crate::expr::window::PyWindowFrame;
2424
use crate::expr::PyExpr;
2525
use datafusion::execution::FunctionRegistry;
2626
use datafusion::functions;
27+
use datafusion::functions_aggregate;
2728
use datafusion_common::{Column, ScalarValue, TableReference};
2829
use datafusion_expr::expr::Alias;
2930
use datafusion_expr::{
3031
aggregate_function,
3132
expr::{
32-
find_df_window_func, AggregateFunction, AggregateFunctionDefinition, ScalarFunction, Sort,
33-
WindowFunction,
33+
find_df_window_func, AggregateFunction, AggregateFunctionDefinition, Sort, WindowFunction,
3434
},
35-
lit, BuiltinScalarFunction, Expr, WindowFunctionDefinition,
35+
lit, Expr, WindowFunctionDefinition,
3636
};
3737

38+
/// Python-exposed constructor for a `covar_samp` aggregate expression over `y` and `x`.
///
/// Per the `pyo3` signature, `distinct` defaults to `false`, and `filter` and
/// `order_by` default to `None` when the Python caller omits them.
///
/// The wrapped `PyExpr` arguments are unwrapped to their inner `expr` values and
/// forwarded to `functions_aggregate::expr_fn::covar_samp`; the result is converted
/// back into a `PyExpr` with `.into()`.
#[pyfunction]
39+
#[pyo3(signature = (y, x, distinct = false, filter = None, order_by = None))]
40+
pub fn covar_samp(
41+
y: PyExpr,
42+
x: PyExpr,
43+
distinct: bool,
44+
filter: Option<PyExpr>,
45+
order_by: Option<Vec<PyExpr>>,
46+
// NOTE(review): this parameter is left disabled; the call below passes `None` in the
// final position, presumably the null-treatment slot — confirm against the upstream API.
// null_treatment: Option<sqlparser::ast::NullTreatment>,
47+
) -> PyExpr {
48+
// Unwrap the optional filter to its inner expression and box it.
// The closure parameter `x` shadows the function argument `x` only inside the closure.
let filter = filter.map(|x| Box::new(x.expr));
49+
// Unwrap each optional ordering entry to its inner expression, keeping the original order.
let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
50+
// The function arguments `y` and `x` (not the closure bindings above) are passed here.
functions_aggregate::expr_fn::covar_samp(y.expr, x.expr, distinct, filter, order_by, None)
51+
.into()
52+
}
53+
54+
/// Python-exposed alias for `covar_samp`.
///
/// Takes the same arguments with the same `pyo3` defaults (`distinct = false`,
/// `filter = None`, `order_by = None`) and returns whatever `covar_samp` returns.
#[pyfunction]
55+
#[pyo3(signature = (y, x, distinct = false, filter = None, order_by = None))]
56+
pub fn covar(
57+
y: PyExpr,
58+
x: PyExpr,
59+
distinct: bool,
60+
filter: Option<PyExpr>,
61+
order_by: Option<Vec<PyExpr>>,
62+
) -> PyExpr {
63+
// Alias for covar_samp: every argument is forwarded unchanged, in the same order.
64+
covar_samp(y, x, distinct, filter, order_by)
65+
}
66+
3867
#[pyfunction]
3968
fn in_list(expr: PyExpr, value: Vec<PyExpr>, negated: bool) -> PyExpr {
4069
datafusion_expr::in_list(
@@ -249,27 +278,6 @@ fn window(
249278
})
250279
}
251280

252-
macro_rules! scalar_function {
253-
($NAME: ident, $FUNC: ident) => {
254-
scalar_function!($NAME, $FUNC, stringify!($NAME));
255-
};
256-
257-
($NAME: ident, $FUNC: ident, $DOC: expr) => {
258-
#[doc = $DOC]
259-
#[pyfunction]
260-
#[pyo3(signature = (*args))]
261-
fn $NAME(args: Vec<PyExpr>) -> PyExpr {
262-
let expr = datafusion_expr::Expr::ScalarFunction(ScalarFunction {
263-
func_def: datafusion_expr::ScalarFunctionDefinition::BuiltIn(
264-
BuiltinScalarFunction::$FUNC,
265-
),
266-
args: args.into_iter().map(|e| e.into()).collect(),
267-
});
268-
expr.into()
269-
}
270-
};
271-
}
272-
273281
macro_rules! aggregate_function {
274282
($NAME: ident, $FUNC: ident) => {
275283
aggregate_function!($NAME, $FUNC, stringify!($NAME));
@@ -370,21 +378,21 @@ macro_rules! array_fn {
370378

371379
expr_fn!(abs, num);
372380
expr_fn!(acos, num);
373-
scalar_function!(acosh, Acosh);
381+
expr_fn!(acosh, num);
374382
expr_fn!(ascii, arg1, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character.");
375383
expr_fn!(asin, num);
376-
scalar_function!(asinh, Asinh);
377-
scalar_function!(atan, Atan);
378-
scalar_function!(atanh, Atanh);
379-
scalar_function!(atan2, Atan2);
384+
expr_fn!(asinh, num);
385+
expr_fn!(atan, num);
386+
expr_fn!(atanh, num);
387+
expr_fn!(atan2, y x);
380388
expr_fn!(
381389
bit_length,
382390
arg,
383391
"Returns number of bits in the string (8 times the octet_length)."
384392
);
385393
expr_fn_vec!(btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string.");
386-
scalar_function!(cbrt, Cbrt);
387-
scalar_function!(ceil, Ceil);
394+
expr_fn!(cbrt, num);
395+
expr_fn!(ceil, num);
388396
expr_fn!(
389397
character_length,
390398
string,
@@ -393,44 +401,44 @@ expr_fn!(
393401
expr_fn!(length, string);
394402
expr_fn!(char_length, string);
395403
expr_fn!(chr, arg, "Returns the character with the given code.");
396-
scalar_function!(coalesce, Coalesce);
397-
scalar_function!(cos, Cos);
398-
scalar_function!(cosh, Cosh);
399-
scalar_function!(degrees, Degrees);
404+
expr_fn_vec!(coalesce);
405+
expr_fn!(cos, num);
406+
expr_fn!(cosh, num);
407+
expr_fn!(degrees, num);
400408
expr_fn!(decode, input encoding);
401409
expr_fn!(encode, input encoding);
402-
scalar_function!(exp, Exp);
403-
scalar_function!(factorial, Factorial);
404-
scalar_function!(floor, Floor);
405-
scalar_function!(gcd, Gcd);
406-
scalar_function!(initcap, InitCap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
410+
expr_fn!(exp, num);
411+
expr_fn!(factorial, num);
412+
expr_fn!(floor, num);
413+
expr_fn!(gcd, x y);
414+
expr_fn!(initcap, string, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
407415
expr_fn!(isnan, num);
408-
scalar_function!(iszero, Iszero);
409-
scalar_function!(lcm, Lcm);
410-
scalar_function!(left, Left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
411-
scalar_function!(ln, Ln);
412-
scalar_function!(log, Log);
413-
scalar_function!(log10, Log10);
414-
scalar_function!(log2, Log2);
416+
expr_fn!(iszero, num);
417+
expr_fn!(lcm, x y);
418+
expr_fn!(left, string n, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
419+
expr_fn!(ln, num);
420+
expr_fn!(log, base num);
421+
expr_fn!(log10, num);
422+
expr_fn!(log2, num);
415423
expr_fn!(lower, arg1, "Converts the string to all lower case");
416-
scalar_function!(lpad, Lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right).");
424+
expr_fn_vec!(lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right).");
417425
expr_fn_vec!(ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string.");
418426
expr_fn!(
419427
md5,
420428
input_arg,
421429
"Computes the MD5 hash of the argument, with the result written in hexadecimal."
422430
);
423-
scalar_function!(
431+
expr_fn!(
424432
nanvl,
425-
Nanvl,
433+
x y,
426434
"Returns x if x is not NaN otherwise returns y."
427435
);
428436
expr_fn!(nullif, arg_1 arg_2);
429437
expr_fn_vec!(octet_length, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
430-
scalar_function!(pi, Pi);
431-
scalar_function!(power, Power);
432-
scalar_function!(pow, Power);
433-
scalar_function!(radians, Radians);
438+
expr_fn!(pi);
439+
expr_fn!(power, base exponent);
440+
expr_fn!(pow, power, base exponent);
441+
expr_fn!(radians, num);
434442
expr_fn!(regexp_match, input_arg1 input_arg2);
435443
expr_fn!(
436444
regexp_replace,
@@ -443,31 +451,31 @@ expr_fn!(
443451
string from to,
444452
"Replaces all occurrences in string of substring from with substring to."
445453
);
446-
scalar_function!(
454+
expr_fn!(
447455
reverse,
448-
Reverse,
456+
string,
449457
"Reverses the order of the characters in the string."
450458
);
451-
scalar_function!(right, Right, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters.");
452-
scalar_function!(round, Round);
453-
scalar_function!(rpad, Rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.");
459+
expr_fn!(right, string n, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters.");
460+
expr_fn_vec!(round);
461+
expr_fn_vec!(rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.");
454462
expr_fn_vec!(rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string.");
455463
expr_fn!(sha224, input_arg1);
456464
expr_fn!(sha256, input_arg1);
457465
expr_fn!(sha384, input_arg1);
458466
expr_fn!(sha512, input_arg1);
459-
scalar_function!(signum, Signum);
460-
scalar_function!(sin, Sin);
461-
scalar_function!(sinh, Sinh);
467+
expr_fn!(signum, num);
468+
expr_fn!(sin, num);
469+
expr_fn!(sinh, num);
462470
expr_fn!(
463471
split_part,
464472
string delimiter index,
465473
"Splits string at occurrences of delimiter and returns the n'th field (counting from one)."
466474
);
467-
scalar_function!(sqrt, Sqrt);
475+
expr_fn!(sqrt, num);
468476
expr_fn!(starts_with, arg1 arg2, "Returns true if string starts with prefix.");
469-
scalar_function!(strpos, Strpos, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
470-
scalar_function!(substr, Substr);
477+
expr_fn!(strpos, string substring, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
478+
expr_fn!(substr, string position);
471479
expr_fn!(tan, num);
472480
expr_fn!(tanh, num);
473481
expr_fn!(
@@ -488,15 +496,15 @@ expr_fn!(date_trunc, part date);
488496
expr_fn!(datetrunc, date_trunc, part date);
489497
expr_fn!(date_bin, stride source origin);
490498

491-
scalar_function!(translate, Translate, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.");
499+
expr_fn!(translate, string from to, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.");
492500
expr_fn_vec!(trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string.");
493-
scalar_function!(trunc, Trunc);
501+
expr_fn_vec!(trunc);
494502
expr_fn!(upper, arg1, "Converts the string to all upper case.");
495503
expr_fn!(uuid);
496-
expr_fn!(r#struct, args); // Use raw identifier since struct is a keyword
504+
expr_fn_vec!(r#struct); // Use raw identifier since struct is a keyword
497505
expr_fn!(from_unixtime, unixtime);
498506
expr_fn!(arrow_typeof, arg_1);
499-
scalar_function!(random, Random);
507+
expr_fn!(random);
500508

501509
// Array Functions
502510
array_fn!(array_append, array element);
@@ -565,9 +573,7 @@ aggregate_function!(array_agg, ArrayAgg);
565573
aggregate_function!(avg, Avg);
566574
aggregate_function!(corr, Correlation);
567575
aggregate_function!(count, Count);
568-
aggregate_function!(covar, Covariance);
569576
aggregate_function!(covar_pop, CovariancePop);
570-
aggregate_function!(covar_samp, Covariance);
571577
aggregate_function!(grouping, Grouping);
572578
aggregate_function!(max, Max);
573579
aggregate_function!(mean, Avg);

0 commit comments

Comments
 (0)