Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 5 additions & 47 deletions datafusion/spark/src/function/bitwise/bitwise_not.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,11 @@ impl ScalarUDFImpl for SparkBitwiseNot {
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
if args.arg_fields.len() != 1 {
return plan_err!("bitwise_not expects exactly 1 argument");
}

let input_field = &args.arg_fields[0];

let out_dt = input_field.data_type().clone();
let mut out_nullable = input_field.is_nullable();

let scalar_null_present = args
.scalar_arguments
.iter()
.any(|opt_s| opt_s.is_some_and(|sv| sv.is_null()));

if scalar_null_present {
out_nullable = true;
}

Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
Ok(Arc::new(Field::new(
self.name(),
args.arg_fields[0].data_type().clone(),
args.arg_fields[0].is_nullable(),
)))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -196,32 +182,4 @@ mod tests {
assert!(out_i64_null.is_nullable());
assert_eq!(out_i64_null.data_type(), &DataType::Int64);
}

#[test]
fn test_bitwise_not_nullability_with_null_scalar() -> Result<()> {
use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
use std::sync::Arc;

let func = SparkBitwiseNot::new();

let non_nullable: FieldRef = Arc::new(Field::new("col", DataType::Int32, false));

let out = func.return_field_from_args(ReturnFieldArgs {
arg_fields: &[Arc::clone(&non_nullable)],
scalar_arguments: &[None],
})?;
assert!(!out.is_nullable());
assert_eq!(out.data_type(), &DataType::Int32);

let null_scalar = ScalarValue::Int32(None);
let out_with_null_scalar = func.return_field_from_args(ReturnFieldArgs {
arg_fields: &[Arc::clone(&non_nullable)],
scalar_arguments: &[Some(&null_scalar)],
})?;
assert!(out_with_null_scalar.is_nullable());
assert_eq!(out_with_null_scalar.data_type(), &DataType::Int32);

Ok(())
}
}
29 changes: 1 addition & 28 deletions datafusion/spark/src/function/datetime/date_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,7 @@ impl ScalarUDFImpl for SparkDateAdd {
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let nullable = args.arg_fields.iter().any(|f| f.is_nullable())
|| args
.scalar_arguments
.iter()
.any(|arg| matches!(arg, Some(sv) if sv.is_null()));

let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
Ok(Arc::new(Field::new(
self.name(),
DataType::Date32,
Expand Down Expand Up @@ -155,7 +150,6 @@ fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> {
mod tests {
use super::*;
use arrow::datatypes::Field;
use datafusion_common::ScalarValue;

#[test]
fn test_date_add_non_nullable_inputs() {
Expand Down Expand Up @@ -194,25 +188,4 @@ mod tests {
assert_eq!(ret_field.data_type(), &DataType::Date32);
assert!(ret_field.is_nullable());
}

#[test]
fn test_date_add_null_scalar() {
let func = SparkDateAdd::new();
let args = &[
Arc::new(Field::new("date", DataType::Date32, false)),
Arc::new(Field::new("num", DataType::Int32, false)),
];

let null_scalar = ScalarValue::Int32(None);

let ret_field = func
.return_field_from_args(ReturnFieldArgs {
arg_fields: args,
scalar_arguments: &[None, Some(&null_scalar)],
})
.unwrap();

assert_eq!(ret_field.data_type(), &DataType::Date32);
assert!(ret_field.is_nullable());
}
}
26 changes: 1 addition & 25 deletions datafusion/spark/src/function/datetime/date_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,7 @@ impl ScalarUDFImpl for SparkDateSub {
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let nullable = args.arg_fields.iter().any(|f| f.is_nullable())
|| args
.scalar_arguments
.iter()
.any(|arg| matches!(arg, Some(sv) if sv.is_null()));

let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
Ok(Arc::new(Field::new(
self.name(),
DataType::Date32,
Expand Down Expand Up @@ -152,7 +147,6 @@ fn spark_date_sub(args: &[ArrayRef]) -> Result<ArrayRef> {
#[cfg(test)]
mod tests {
use super::*;
use datafusion_common::ScalarValue;

#[test]
fn test_date_sub_nullability_non_nullable_args() {
Expand Down Expand Up @@ -187,22 +181,4 @@ mod tests {
assert!(result.is_nullable());
assert_eq!(result.data_type(), &DataType::Date32);
}

#[test]
fn test_date_sub_nullability_scalar_null_argument() {
let udf = SparkDateSub::new();
let date_field = Arc::new(Field::new("d", DataType::Date32, false));
let days_field = Arc::new(Field::new("n", DataType::Int32, false));
let null_scalar = ScalarValue::Int32(None);

let result = udf
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[date_field, days_field],
scalar_arguments: &[None, Some(&null_scalar)],
})
.unwrap();

assert!(result.is_nullable());
assert_eq!(result.data_type(), &DataType::Date32);
}
}