8 changes: 4 additions & 4 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -19,11 +19,11 @@ use crate::scalar_funcs::hash_expressions::{
     spark_sha224, spark_sha256, spark_sha384, spark_sha512,
 };
 use crate::scalar_funcs::{
-    spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
-    spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_round, spark_unhex,
-    spark_unscaled_value, spark_xxhash64, SparkChrFunc,
+    spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
+    spark_murmur3_hash, spark_round, spark_unhex, spark_unscaled_value, spark_xxhash64,
+    SparkChrFunc,
 };
-use crate::spark_read_side_padding;
+use crate::{spark_date_add, spark_date_sub, spark_read_side_padding};
 use arrow_schema::DataType;
 use datafusion_common::{DataFusionError, Result as DataFusionResult};
 use datafusion_expr::registry::FunctionRegistry;
102 changes: 102 additions & 0 deletions native/spark-expr/src/datetime_funcs/date_arithmetic.rs
@@ -0,0 +1,102 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{ArrayRef, AsArray};
use arrow::compute::kernels::numeric::{add, sub};
use arrow::datatypes::IntervalDayTime;
use arrow_array::builder::IntervalDayTimeBuilder;
use arrow_array::types::{Int16Type, Int32Type, Int8Type};
use arrow_array::{Array, Datum};
use arrow_schema::{ArrowError, DataType};
use datafusion::physical_expr_common::datum;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use std::sync::Arc;

macro_rules! scalar_date_arithmetic {
    ($start:expr, $days:expr, $op:expr) => {{
        let interval = IntervalDayTime::new(*$days as i32, 0);
        let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
        datum::apply($start, &interval_cv, $op)
    }};
}

macro_rules! array_date_arithmetic {
    ($days:expr, $interval_builder:expr, $intType:ty) => {{
        for day in $days.as_primitive::<$intType>().into_iter() {
            if let Some(non_null_day) = day {
                $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
            } else {
                $interval_builder.append_null();
            }
        }
    }};
}

/// Spark-compatible `date_add` and `date_sub` expressions, which take a number of days as the
/// second argument. A day count cannot be added directly to a Date32, so we build an
/// IntervalDayTime from the second argument and use DataFusion's interface to apply Arrow's
/// arithmetic kernels.
fn spark_date_arithmetic(
    args: &[ColumnarValue],
    op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
) -> Result<ColumnarValue, DataFusionError> {
    let start = &args[0];
    match &args[1] {
        ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
            scalar_date_arithmetic!(start, days, op)
        }
        ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
            scalar_date_arithmetic!(start, days, op)
        }
        ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
            scalar_date_arithmetic!(start, days, op)
        }
        ColumnarValue::Array(days) => {
            let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
            match days.data_type() {
                DataType::Int8 => {
                    array_date_arithmetic!(days, interval_builder, Int8Type)
                }
                DataType::Int16 => {
                    array_date_arithmetic!(days, interval_builder, Int16Type)
                }
                DataType::Int32 => {
                    array_date_arithmetic!(days, interval_builder, Int32Type)
                }
                _ => {
                    return Err(DataFusionError::Internal(format!(
                        "Unsupported data types {:?} for date arithmetic.",
                        args,
                    )))
                }
            }
            let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
            datum::apply(start, &interval_cv, op)
        }
        _ => Err(DataFusionError::Internal(format!(
            "Unsupported data types {:?} for date arithmetic.",
            args,
        ))),
    }
}

pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    spark_date_arithmetic(args, add)
}

pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    spark_date_arithmetic(args, sub)
}
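
A minimal usage sketch for the functions above, assuming the arrow and datafusion crates already imported in this file; the `example` wrapper and its values are illustrative, not part of the diff:

use std::sync::Arc;

use arrow_array::Date32Array;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};

fn example() -> Result<(), DataFusionError> {
    // Dates as days since the Unix epoch; the second entry is null.
    let dates = ColumnarValue::Array(Arc::new(Date32Array::from(vec![Some(19_000), None])));
    // Scalar day count, matching Spark's date_add(start_date, num_days) signature.
    let days = ColumnarValue::Scalar(ScalarValue::Int32(Some(7)));
    // spark_date_add builds an IntervalDayTime(7, 0) and applies Arrow's add kernel;
    // nulls in the date column propagate to the output.
    let _shifted = spark_date_add(&[dates, days])?;
    Ok(())
}
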
113 changes: 113 additions & 0 deletions native/spark-expr/src/datetime_funcs/date_trunc.rs
@@ -0,0 +1,113 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::record_batch::RecordBatch;
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue::Utf8};
use datafusion_physical_expr::PhysicalExpr;
use std::hash::Hash;
use std::{
    any::Any,
    fmt::{Debug, Display, Formatter},
    sync::Arc,
};

use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};

#[derive(Debug, Eq)]
pub struct DateTruncExpr {
    /// An array with DataType::Date32
    child: Arc<dyn PhysicalExpr>,
    /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc
    format: Arc<dyn PhysicalExpr>,
}

impl Hash for DateTruncExpr {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.child.hash(state);
        self.format.hash(state);
    }
}

impl PartialEq for DateTruncExpr {
    fn eq(&self, other: &Self) -> bool {
        self.child.eq(&other.child) && self.format.eq(&other.format)
    }
}

impl DateTruncExpr {
    pub fn new(child: Arc<dyn PhysicalExpr>, format: Arc<dyn PhysicalExpr>) -> Self {
        DateTruncExpr { child, format }
    }
}

impl Display for DateTruncExpr {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "DateTrunc [child: {}, format: {}]",
            self.child, self.format
        )
    }
}

impl PhysicalExpr for DateTruncExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
        self.child.data_type(input_schema)
    }

    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
        Ok(true)
    }

    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
        let date = self.child.evaluate(batch)?;
        let format = self.format.evaluate(batch)?;
        match (date, format) {
            (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
                let result = date_trunc_dyn(&date, format)?;
                Ok(ColumnarValue::Array(result))
            }
            (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => {
                let result = date_trunc_array_fmt_dyn(&date, &formats)?;
                Ok(ColumnarValue::Array(result))
            }
            _ => Err(DataFusionError::Execution(
                "Invalid input to function DateTrunc. Expected (PrimitiveArray<Date32>, Scalar) or \
                 (PrimitiveArray<Date32>, StringArray)".to_string(),
            )),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
        Ok(Arc::new(DateTruncExpr::new(
            Arc::clone(&children[0]),
            Arc::clone(&self.format),
        )))
    }
}
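
A construction sketch for DateTruncExpr, assuming `Column` and `Literal` from datafusion_physical_expr::expressions; the column name, ordinal, and format value are illustrative:

use std::sync::Arc;

use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::{Column, Literal};

fn build_trunc_expr() -> DateTruncExpr {
    // Truncate the Date32 column at ordinal 0 to the first day of its month;
    // "month" is one of the format strings Spark's trunc accepts.
    let child = Arc::new(Column::new("d", 0));
    let format = Arc::new(Literal::new(ScalarValue::Utf8(Some("month".to_string()))));
    DateTruncExpr::new(child, format)
}

Evaluating the resulting expression against a RecordBatch whose column 0 is Date32 takes the scalar-format arm of evaluate above; a StringArray of per-row formats would take the array-format arm instead.
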
122 changes: 122 additions & 0 deletions native/spark-expr/src/datetime_funcs/hour.rs
@@ -0,0 +1,122 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::utils::array_with_timezone;
use arrow::{
    compute::{date_part, DatePart},
    record_batch::RecordBatch,
};
use arrow_schema::{DataType, Schema, TimeUnit::Microsecond};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::DataFusionError;
use datafusion_physical_expr::PhysicalExpr;
use std::hash::Hash;
use std::{
    any::Any,
    fmt::{Debug, Display, Formatter},
    sync::Arc,
};

#[derive(Debug, Eq)]
pub struct HourExpr {
    /// An array with DataType::Timestamp(TimeUnit::Microsecond, None)
    child: Arc<dyn PhysicalExpr>,
    timezone: String,
}

impl Hash for HourExpr {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.child.hash(state);
        self.timezone.hash(state);
    }
}

impl PartialEq for HourExpr {
    fn eq(&self, other: &Self) -> bool {
        self.child.eq(&other.child) && self.timezone.eq(&other.timezone)
    }
}

impl HourExpr {
    pub fn new(child: Arc<dyn PhysicalExpr>, timezone: String) -> Self {
        HourExpr { child, timezone }
    }
}

impl Display for HourExpr {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Hour [timezone: {}, child: {}]",
            self.timezone, self.child
        )
    }
}

impl PhysicalExpr for HourExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
        match self.child.data_type(input_schema)? {
            DataType::Dictionary(key_type, _) => {
                Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32)))
            }
            _ => Ok(DataType::Int32),
        }
    }

    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
        Ok(true)
    }

    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
        let arg = self.child.evaluate(batch)?;
        match arg {
            ColumnarValue::Array(array) => {
                let array = array_with_timezone(
                    array,
                    self.timezone.clone(),
                    Some(&DataType::Timestamp(
                        Microsecond,
                        Some(self.timezone.clone().into()),
                    )),
                )?;
                let result = date_part(&array, DatePart::Hour)?;
                Ok(ColumnarValue::Array(result))
            }
            _ => Err(DataFusionError::Execution(
                "Hour(scalar) should be folded on the Spark JVM side.".to_string(),
            )),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
        Ok(Arc::new(HourExpr::new(
            Arc::clone(&children[0]),
            self.timezone.clone(),
        )))
    }
}
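
A matching sketch for HourExpr, under the same assumption about `Column`; the timezone string and column layout are illustrative:

use std::sync::Arc;

use datafusion_physical_expr::expressions::Column;

fn build_hour_expr() -> HourExpr {
    // Extract the hour of day (0-23) from the microsecond-timestamp column at
    // ordinal 0, interpreting each instant in the session timezone "UTC".
    HourExpr::new(Arc::new(Column::new("ts", 0)), "UTC".to_string())
}

Per data_type above, the result is Int32, or a dictionary with Int32 values when the input is dictionary-encoded.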