Skip to content

Commit 219a133

Browse files
authored
chore: extract math_funcs expressions to folders based on spark grouping (apache#1219)
* extract math_funcs expressions to folders based on spark grouping
* fix merge conflicts and move chr to `string_funcs`
1 parent ba08511 commit 219a133

File tree

21 files changed

+661
-589
lines changed

21 files changed

+661
-589
lines changed

benches/decimal_div.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use arrow::compute::cast;
1919
use arrow_array::builder::Decimal128Builder;
2020
use arrow_schema::DataType;
2121
use criterion::{black_box, criterion_group, criterion_main, Criterion};
22-
use datafusion_comet_spark_expr::scalar_funcs::spark_decimal_div;
22+
use datafusion_comet_spark_expr::spark_decimal_div;
2323
use datafusion_expr::ColumnarValue;
2424
use std::sync::Arc;
2525

src/comet_scalar_funcs.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
// under the License.
1717

1818
use crate::hash_funcs::*;
19-
use crate::scalar_funcs::{
20-
spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
21-
spark_round, spark_unhex, spark_unscaled_value, SparkChrFunc,
19+
use crate::{
20+
spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
21+
spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_unhex,
22+
spark_unscaled_value, SparkChrFunc,
2223
};
23-
use crate::{spark_date_add, spark_date_sub, spark_read_side_padding};
2424
use arrow_schema::DataType;
2525
use datafusion_common::{DataFusionError, Result as DataFusionResult};
2626
use datafusion_expr::registry::FunctionRegistry;

src/hash_funcs/sha2.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::scalar_funcs::hex_strings;
18+
use crate::math_funcs::hex::hex_strings;
1919
use arrow_array::{Array, StringArray};
2020
use datafusion::functions::crypto::{sha224, sha256, sha384, sha512};
2121
use datafusion_common::cast::as_binary_array;

src/lib.rs

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,22 @@
2121

2222
mod error;
2323

24-
mod checkoverflow;
25-
pub use checkoverflow::CheckOverflow;
26-
2724
mod kernels;
28-
pub mod scalar_funcs;
2925
mod schema_adapter;
3026
mod static_invoke;
3127
pub use schema_adapter::SparkSchemaAdapterFactory;
3228
pub use static_invoke::*;
3329

34-
mod negative;
3530
mod struct_funcs;
36-
pub use negative::{create_negate_expr, NegativeExpr};
37-
mod normalize_nan;
31+
pub use struct_funcs::{CreateNamedStruct, GetStructField};
3832

3933
mod json_funcs;
4034
pub mod test_common;
4135
pub mod timezone;
4236
mod unbound;
4337
pub use unbound::UnboundColumn;
44-
pub mod utils;
45-
pub use normalize_nan::NormalizeNaNAndZero;
4638
mod predicate_funcs;
39+
pub mod utils;
4740
pub use predicate_funcs::{spark_isnan, RLike};
4841

4942
mod agg_funcs;
@@ -57,24 +50,30 @@ mod string_funcs;
5750
mod datetime_funcs;
5851
pub use agg_funcs::*;
5952

60-
pub use crate::{CreateNamedStruct, GetStructField};
61-
pub use crate::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
6253
pub use cast::{spark_cast, Cast, SparkCastOptions};
6354
mod conditional_funcs;
6455
mod conversion_funcs;
56+
mod math_funcs;
6557

6658
pub use array_funcs::*;
6759
pub use bitwise_funcs::*;
6860
pub use conditional_funcs::*;
6961
pub use conversion_funcs::*;
7062

7163
pub use comet_scalar_funcs::create_comet_physical_fun;
72-
pub use datetime_funcs::*;
64+
pub use datetime_funcs::{
65+
spark_date_add, spark_date_sub, DateTruncExpr, HourExpr, MinuteExpr, SecondExpr,
66+
TimestampTruncExpr,
67+
};
7368
pub use error::{SparkError, SparkResult};
7469
pub use hash_funcs::*;
7570
pub use json_funcs::ToJson;
71+
pub use math_funcs::{
72+
create_negate_expr, spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_make_decimal,
73+
spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr,
74+
NormalizeNaNAndZero,
75+
};
7676
pub use string_funcs::*;
77-
pub use struct_funcs::*;
7877

7978
/// Spark supports three evaluation modes when evaluating expressions, which affect
8079
/// the behavior when processing input values that are invalid or would result in an

src/math_funcs/ceil.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::downcast_compute_op;
19+
use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
20+
use arrow::array::{Float32Array, Float64Array, Int64Array};
21+
use arrow_array::{Array, ArrowNativeTypeOp};
22+
use arrow_schema::DataType;
23+
use datafusion::physical_plan::ColumnarValue;
24+
use datafusion_common::{DataFusionError, ScalarValue};
25+
use num::integer::div_ceil;
26+
use std::sync::Arc;
27+
28+
/// `ceil` function that simulates Spark `ceil` expression
29+
pub fn spark_ceil(
30+
args: &[ColumnarValue],
31+
data_type: &DataType,
32+
) -> Result<ColumnarValue, DataFusionError> {
33+
let value = &args[0];
34+
match value {
35+
ColumnarValue::Array(array) => match array.data_type() {
36+
DataType::Float32 => {
37+
let result = downcast_compute_op!(array, "ceil", ceil, Float32Array, Int64Array);
38+
Ok(ColumnarValue::Array(result?))
39+
}
40+
DataType::Float64 => {
41+
let result = downcast_compute_op!(array, "ceil", ceil, Float64Array, Int64Array);
42+
Ok(ColumnarValue::Array(result?))
43+
}
44+
DataType::Int64 => {
45+
let result = array.as_any().downcast_ref::<Int64Array>().unwrap();
46+
Ok(ColumnarValue::Array(Arc::new(result.clone())))
47+
}
48+
DataType::Decimal128(_, scale) if *scale > 0 => {
49+
let f = decimal_ceil_f(scale);
50+
let (precision, scale) = get_precision_scale(data_type);
51+
make_decimal_array(array, precision, scale, &f)
52+
}
53+
other => Err(DataFusionError::Internal(format!(
54+
"Unsupported data type {:?} for function ceil",
55+
other,
56+
))),
57+
},
58+
ColumnarValue::Scalar(a) => match a {
59+
ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
60+
a.map(|x| x.ceil() as i64),
61+
))),
62+
ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
63+
a.map(|x| x.ceil() as i64),
64+
))),
65+
ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
66+
ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
67+
let f = decimal_ceil_f(scale);
68+
let (precision, scale) = get_precision_scale(data_type);
69+
make_decimal_scalar(a, precision, scale, &f)
70+
}
71+
_ => Err(DataFusionError::Internal(format!(
72+
"Unsupported data type {:?} for function ceil",
73+
value.data_type(),
74+
))),
75+
},
76+
}
77+
}
78+
79+
#[inline]
80+
fn decimal_ceil_f(scale: &i8) -> impl Fn(i128) -> i128 {
81+
let div = 10_i128.pow_wrapping(*scale as u32);
82+
move |x: i128| div_ceil(x, div)
83+
}

src/math_funcs/div.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::math_funcs::utils::get_precision_scale;
19+
use arrow::{
20+
array::{ArrayRef, AsArray},
21+
datatypes::Decimal128Type,
22+
};
23+
use arrow_array::{Array, Decimal128Array};
24+
use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
25+
use datafusion::physical_plan::ColumnarValue;
26+
use datafusion_common::DataFusionError;
27+
use num::{BigInt, Signed, ToPrimitive};
28+
use std::sync::Arc;
29+
30+
// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
// both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale >
// Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt.
pub fn spark_decimal_div(
    args: &[ColumnarValue],
    data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
    let left = &args[0];
    let right = &args[1];
    // Precision/scale of the result, taken from the planner-supplied return type.
    let (p3, s3) = get_precision_scale(data_type);

    // Normalize all four scalar/array combinations into a pair of equal-length arrays.
    let (left, right): (ArrayRef, ArrayRef) = match (left, right) {
        (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
        (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
            (l.to_array_of_size(r.len())?, Arc::clone(r))
        }
        (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
            (Arc::clone(l), r.to_array_of_size(l.len())?)
        }
        (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
    };
    let left = left.as_primitive::<Decimal128Type>();
    let right = right.as_primitive::<Decimal128Type>();
    let (p1, s1) = get_precision_scale(left.data_type());
    let (p2, s2) = get_precision_scale(right.data_type());

    // Widen the dividend's scale to s2 + s3 + 1 (the "+1" is an extra digit used for
    // rounding below). Exactly one of l_exp / r_exp is nonzero: whichever operand is
    // short on scale gets multiplied by that power of ten.
    let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32);
    let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32);
    // If either widened operand could exceed Decimal128 precision, take the slow
    // BigInt path; otherwise everything fits in i128 arithmetic.
    let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32
        || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32
    {
        let ten = BigInt::from(10);
        let l_mul = ten.pow(l_exp);
        let r_mul = ten.pow(r_exp);
        let five = BigInt::from(5);
        let zero = BigInt::from(0);
        arrow::compute::kernels::arity::binary(left, right, |l, r| {
            let l = BigInt::from(l) * &l_mul;
            let r = BigInt::from(r) * &r_mul;
            // Division by zero yields zero here rather than an error — presumably
            // matching Spark's non-ANSI behavior; confirm against the caller.
            let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
            // Round half away from zero using the extra scale digit, then drop it.
            let res = if div.is_negative() {
                div - &five
            } else {
                div + &five
            } / &ten;
            // NOTE(review): saturates to i128::MAX on overflow in BOTH directions —
            // a large negative result also becomes MAX; confirm this is intended.
            res.to_i128().unwrap_or(i128::MAX)
        })?
    } else {
        // Fast path: operands and intermediates fit in i128.
        let l_mul = 10_i128.pow(l_exp);
        let r_mul = 10_i128.pow(r_exp);
        arrow::compute::kernels::arity::binary(left, right, |l, r| {
            let l = l * l_mul;
            let r = r * r_mul;
            // Division by zero yields zero, as in the BigInt path above.
            let div = if r == 0 { 0 } else { l / r };
            // Round half away from zero using the extra scale digit, then drop it.
            let res = if div.is_negative() { div - 5 } else { div + 5 } / 10;
            res.to_i128().unwrap_or(i128::MAX)
        })?
    };
    // Stamp the Spark-expected precision/scale onto the raw i128 results.
    let result = result.with_data_type(DataType::Decimal128(p3, s3));
    Ok(ColumnarValue::Array(Arc::new(result)))
}

src/math_funcs/floor.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::downcast_compute_op;
19+
use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
20+
use arrow::array::{Float32Array, Float64Array, Int64Array};
21+
use arrow_array::{Array, ArrowNativeTypeOp};
22+
use arrow_schema::DataType;
23+
use datafusion::physical_plan::ColumnarValue;
24+
use datafusion_common::{DataFusionError, ScalarValue};
25+
use num::integer::div_floor;
26+
use std::sync::Arc;
27+
28+
/// `floor` function that simulates Spark `floor` expression
29+
pub fn spark_floor(
30+
args: &[ColumnarValue],
31+
data_type: &DataType,
32+
) -> Result<ColumnarValue, DataFusionError> {
33+
let value = &args[0];
34+
match value {
35+
ColumnarValue::Array(array) => match array.data_type() {
36+
DataType::Float32 => {
37+
let result = downcast_compute_op!(array, "floor", floor, Float32Array, Int64Array);
38+
Ok(ColumnarValue::Array(result?))
39+
}
40+
DataType::Float64 => {
41+
let result = downcast_compute_op!(array, "floor", floor, Float64Array, Int64Array);
42+
Ok(ColumnarValue::Array(result?))
43+
}
44+
DataType::Int64 => {
45+
let result = array.as_any().downcast_ref::<Int64Array>().unwrap();
46+
Ok(ColumnarValue::Array(Arc::new(result.clone())))
47+
}
48+
DataType::Decimal128(_, scale) if *scale > 0 => {
49+
let f = decimal_floor_f(scale);
50+
let (precision, scale) = get_precision_scale(data_type);
51+
make_decimal_array(array, precision, scale, &f)
52+
}
53+
other => Err(DataFusionError::Internal(format!(
54+
"Unsupported data type {:?} for function floor",
55+
other,
56+
))),
57+
},
58+
ColumnarValue::Scalar(a) => match a {
59+
ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
60+
a.map(|x| x.floor() as i64),
61+
))),
62+
ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
63+
a.map(|x| x.floor() as i64),
64+
))),
65+
ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))),
66+
ScalarValue::Decimal128(a, _, scale) if *scale > 0 => {
67+
let f = decimal_floor_f(scale);
68+
let (precision, scale) = get_precision_scale(data_type);
69+
make_decimal_scalar(a, precision, scale, &f)
70+
}
71+
_ => Err(DataFusionError::Internal(format!(
72+
"Unsupported data type {:?} for function floor",
73+
value.data_type(),
74+
))),
75+
},
76+
}
77+
}
78+
79+
#[inline]
80+
fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 {
81+
let div = 10_i128.pow_wrapping(*scale as u32);
82+
move |x: i128| div_floor(x, div)
83+
}
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)