Skip to content

Commit 78f8ef1

Browse files
Omega359alamb
andauthored
move Trunc, Cot, Round, iszero functions to datafusion-functions (#10000)
* move Floor, Gcd, Lcm, Pi to datafusion-functions * remove floor fn * move Trunc, Cot, Round, iszero functions to datafusion-functions * Make mod iszero public, minor ordering change to keep the alphabetical ordering theme. --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 0088c28 commit 78f8ef1

File tree

17 files changed

+1061
-692
lines changed

17 files changed

+1061
-692
lines changed

datafusion/expr/src/built_in_function.rs

Lines changed: 7 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,8 @@ pub enum BuiltinScalarFunction {
4545
Exp,
4646
/// factorial
4747
Factorial,
48-
/// iszero
49-
Iszero,
5048
/// nanvl
5149
Nanvl,
52-
/// round
53-
Round,
54-
/// trunc
55-
Trunc,
56-
/// cot
57-
Cot,
58-
5950
// string functions
6051
/// concat
6152
Concat,
@@ -123,11 +114,7 @@ impl BuiltinScalarFunction {
123114
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
124115
BuiltinScalarFunction::Exp => Volatility::Immutable,
125116
BuiltinScalarFunction::Factorial => Volatility::Immutable,
126-
BuiltinScalarFunction::Iszero => Volatility::Immutable,
127117
BuiltinScalarFunction::Nanvl => Volatility::Immutable,
128-
BuiltinScalarFunction::Round => Volatility::Immutable,
129-
BuiltinScalarFunction::Cot => Volatility::Immutable,
130-
BuiltinScalarFunction::Trunc => Volatility::Immutable,
131118
BuiltinScalarFunction::Concat => Volatility::Immutable,
132119
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
133120
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
@@ -175,16 +162,12 @@ impl BuiltinScalarFunction {
175162
_ => Ok(Float64),
176163
},
177164

178-
BuiltinScalarFunction::Iszero => Ok(Boolean),
179-
180-
BuiltinScalarFunction::Ceil
181-
| BuiltinScalarFunction::Exp
182-
| BuiltinScalarFunction::Round
183-
| BuiltinScalarFunction::Trunc
184-
| BuiltinScalarFunction::Cot => match input_expr_types[0] {
185-
Float32 => Ok(Float32),
186-
_ => Ok(Float64),
187-
},
165+
BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => {
166+
match input_expr_types[0] {
167+
Float32 => Ok(Float32),
168+
_ => Ok(Float64),
169+
}
170+
}
188171
}
189172
}
190173

@@ -217,45 +200,21 @@ impl BuiltinScalarFunction {
217200
self.volatility(),
218201
),
219202
BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()),
220-
BuiltinScalarFunction::Round => Signature::one_of(
221-
vec![
222-
Exact(vec![Float64, Int64]),
223-
Exact(vec![Float32, Int64]),
224-
Exact(vec![Float64]),
225-
Exact(vec![Float32]),
226-
],
227-
self.volatility(),
228-
),
229-
BuiltinScalarFunction::Trunc => Signature::one_of(
230-
vec![
231-
Exact(vec![Float32, Int64]),
232-
Exact(vec![Float64, Int64]),
233-
Exact(vec![Float64]),
234-
Exact(vec![Float32]),
235-
],
236-
self.volatility(),
237-
),
238203
BuiltinScalarFunction::Nanvl => Signature::one_of(
239204
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
240205
self.volatility(),
241206
),
242207
BuiltinScalarFunction::Factorial => {
243208
Signature::uniform(1, vec![Int64], self.volatility())
244209
}
245-
BuiltinScalarFunction::Ceil
246-
| BuiltinScalarFunction::Exp
247-
| BuiltinScalarFunction::Cot => {
210+
BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => {
248211
// math expressions expect 1 argument of type f64 or f32
249212
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
250213
// return the best approximation for it (in f64).
251214
// We accept f32 because in this case it is clear that the best approximation
252215
// will be as good as the number of digits in the number
253216
Signature::uniform(1, vec![Float64, Float32], self.volatility())
254217
}
255-
BuiltinScalarFunction::Iszero => Signature::one_of(
256-
vec![Exact(vec![Float32]), Exact(vec![Float64])],
257-
self.volatility(),
258-
),
259218
}
260219
}
261220

@@ -268,8 +227,6 @@ impl BuiltinScalarFunction {
268227
BuiltinScalarFunction::Ceil
269228
| BuiltinScalarFunction::Exp
270229
| BuiltinScalarFunction::Factorial
271-
| BuiltinScalarFunction::Round
272-
| BuiltinScalarFunction::Trunc
273230
) {
274231
Some(vec![Some(true)])
275232
} else {
@@ -281,14 +238,10 @@ impl BuiltinScalarFunction {
281238
pub fn aliases(&self) -> &'static [&'static str] {
282239
match self {
283240
BuiltinScalarFunction::Ceil => &["ceil"],
284-
BuiltinScalarFunction::Cot => &["cot"],
285241
BuiltinScalarFunction::Exp => &["exp"],
286242
BuiltinScalarFunction::Factorial => &["factorial"],
287-
BuiltinScalarFunction::Iszero => &["iszero"],
288243
BuiltinScalarFunction::Nanvl => &["nanvl"],
289244
BuiltinScalarFunction::Random => &["random"],
290-
BuiltinScalarFunction::Round => &["round"],
291-
BuiltinScalarFunction::Trunc => &["trunc"],
292245

293246
// conditional functions
294247
BuiltinScalarFunction::Coalesce => &["coalesce"],

datafusion/expr/src/expr_fn.rs

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -530,20 +530,14 @@ macro_rules! nary_scalar_expr {
530530
// generate methods for creating the supported unary/binary expressions
531531

532532
// math functions
533-
scalar_expr!(Cot, cot, num, "cotangent of a number");
534533
scalar_expr!(Factorial, factorial, num, "factorial");
535534
scalar_expr!(
536535
Ceil,
537536
ceil,
538537
num,
539538
"nearest integer greater than or equal to argument"
540539
);
541-
nary_scalar_expr!(Round, round, "round to nearest integer");
542-
nary_scalar_expr!(
543-
Trunc,
544-
trunc,
545-
"truncate toward zero, with optional precision"
546-
);
540+
547541
scalar_expr!(Exp, exp, num, "exponential");
548542

549543
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
@@ -557,12 +551,6 @@ nary_scalar_expr!(
557551
);
558552
nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
559553
scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y");
560-
scalar_expr!(
561-
Iszero,
562-
iszero,
563-
num,
564-
"returns true if a given number is +0.0 or -0.0 otherwise returns false"
565-
);
566554

567555
/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
568556
pub fn case(expr: Expr) -> CaseBuilder {
@@ -872,12 +860,6 @@ impl WindowUDFImpl for SimpleWindowUDF {
872860
}
873861

874862
/// Calls a named built in function
875-
/// ```
876-
/// use datafusion_expr::{col, lit, call_fn};
877-
///
878-
/// // create the expression trunc(x) < 0.2
879-
/// let expr = call_fn("trunc", vec![col("x")]).unwrap().lt(lit(0.2));
880-
/// ```
881863
pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
882864
match name.as_ref().parse::<BuiltinScalarFunction>() {
883865
Ok(fun) => Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))),
@@ -935,38 +917,12 @@ mod test {
935917
};
936918
}
937919

938-
macro_rules! test_nary_scalar_expr {
939-
($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
940-
let expected = [$(stringify!($arg)),*];
941-
let result = $FUNC(
942-
vec![
943-
$(
944-
col(stringify!($arg.to_string()))
945-
),*
946-
]
947-
);
948-
if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result {
949-
let name = built_in_function::BuiltinScalarFunction::$ENUM;
950-
assert_eq!(name, fun);
951-
assert_eq!(expected.len(), args.len());
952-
} else {
953-
assert!(false, "unexpected: {:?}", result);
954-
}
955-
};
956-
}
957-
958920
#[test]
959921
fn scalar_function_definitions() {
960-
test_unary_scalar_expr!(Cot, cot);
961922
test_unary_scalar_expr!(Factorial, factorial);
962923
test_unary_scalar_expr!(Ceil, ceil);
963-
test_nary_scalar_expr!(Round, round, input);
964-
test_nary_scalar_expr!(Round, round, input, decimal_places);
965-
test_nary_scalar_expr!(Trunc, trunc, num);
966-
test_nary_scalar_expr!(Trunc, trunc, num, precision);
967924
test_unary_scalar_expr!(Exp, exp);
968925
test_scalar_expr!(Nanvl, nanvl, x, y);
969-
test_scalar_expr!(Iszero, iszero, input);
970926

971927
test_scalar_expr!(InitCap, initcap, string);
972928
test_scalar_expr!(EndsWith, ends_with, string, characters);
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use std::sync::Arc;
20+
21+
use arrow::array::{ArrayRef, Float32Array, Float64Array};
22+
use arrow::datatypes::DataType;
23+
use arrow::datatypes::DataType::{Float32, Float64};
24+
25+
use datafusion_common::{exec_err, DataFusionError, Result};
26+
use datafusion_expr::ColumnarValue;
27+
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28+
29+
use crate::utils::make_scalar_function;
30+
31+
#[derive(Debug)]
32+
pub struct CotFunc {
33+
signature: Signature,
34+
}
35+
36+
impl Default for CotFunc {
37+
fn default() -> Self {
38+
CotFunc::new()
39+
}
40+
}
41+
42+
impl CotFunc {
43+
pub fn new() -> Self {
44+
use DataType::*;
45+
Self {
46+
// math expressions expect 1 argument of type f64 or f32
47+
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
48+
// return the best approximation for it (in f64).
49+
// We accept f32 because in this case it is clear that the best approximation
50+
// will be as good as the number of digits in the number
51+
signature: Signature::uniform(
52+
1,
53+
vec![Float64, Float32],
54+
Volatility::Immutable,
55+
),
56+
}
57+
}
58+
}
59+
60+
impl ScalarUDFImpl for CotFunc {
61+
fn as_any(&self) -> &dyn Any {
62+
self
63+
}
64+
65+
fn name(&self) -> &str {
66+
"cot"
67+
}
68+
69+
fn signature(&self) -> &Signature {
70+
&self.signature
71+
}
72+
73+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
74+
match arg_types[0] {
75+
Float32 => Ok(Float32),
76+
_ => Ok(Float64),
77+
}
78+
}
79+
80+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
81+
make_scalar_function(cot, vec![])(args)
82+
}
83+
}
84+
85+
///cot SQL function
86+
fn cot(args: &[ArrayRef]) -> Result<ArrayRef> {
87+
match args[0].data_type() {
88+
Float64 => Ok(Arc::new(make_function_scalar_inputs!(
89+
&args[0],
90+
"x",
91+
Float64Array,
92+
{ compute_cot64 }
93+
)) as ArrayRef),
94+
Float32 => Ok(Arc::new(make_function_scalar_inputs!(
95+
&args[0],
96+
"x",
97+
Float32Array,
98+
{ compute_cot32 }
99+
)) as ArrayRef),
100+
other => exec_err!("Unsupported data type {other:?} for function cot"),
101+
}
102+
}
103+
104+
fn compute_cot32(x: f32) -> f32 {
105+
let a = f32::tan(x);
106+
1.0 / a
107+
}
108+
109+
fn compute_cot64(x: f64) -> f64 {
110+
let a = f64::tan(x);
111+
1.0 / a
112+
}
113+
114+
#[cfg(test)]
115+
mod test {
116+
use crate::math::cot::cot;
117+
use arrow::array::{ArrayRef, Float32Array, Float64Array};
118+
use datafusion_common::cast::{as_float32_array, as_float64_array};
119+
use std::sync::Arc;
120+
121+
#[test]
122+
fn test_cot_f32() {
123+
let args: Vec<ArrayRef> =
124+
vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
125+
let result = cot(&args).expect("failed to initialize function cot");
126+
let floats =
127+
as_float32_array(&result).expect("failed to initialize function cot");
128+
129+
let expected = Float32Array::from(vec![
130+
-1.986_460_4,
131+
-0.156_119_96,
132+
-0.501_202_8,
133+
0.156_119_96,
134+
]);
135+
136+
let eps = 1e-6;
137+
assert_eq!(floats.len(), 4);
138+
assert!((floats.value(0) - expected.value(0)).abs() < eps);
139+
assert!((floats.value(1) - expected.value(1)).abs() < eps);
140+
assert!((floats.value(2) - expected.value(2)).abs() < eps);
141+
assert!((floats.value(3) - expected.value(3)).abs() < eps);
142+
}
143+
144+
#[test]
145+
fn test_cot_f64() {
146+
let args: Vec<ArrayRef> =
147+
vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
148+
let result = cot(&args).expect("failed to initialize function cot");
149+
let floats =
150+
as_float64_array(&result).expect("failed to initialize function cot");
151+
152+
let expected = Float64Array::from(vec![
153+
-1.986_458_685_881_4,
154+
-0.156_119_952_161_6,
155+
-0.501_202_783_380_1,
156+
0.156_119_952_161_6,
157+
]);
158+
159+
let eps = 1e-12;
160+
assert_eq!(floats.len(), 4);
161+
assert!((floats.value(0) - expected.value(0)).abs() < eps);
162+
assert!((floats.value(1) - expected.value(1)).abs() < eps);
163+
assert!((floats.value(2) - expected.value(2)).abs() < eps);
164+
assert!((floats.value(3) - expected.value(3)).abs() < eps);
165+
}
166+
}

0 commit comments

Comments
 (0)