diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 5444b3a88f05..fc0a4e7c7ed2 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -60,6 +60,9 @@ async fn test_mathematical_expressions_with_null() -> Result<()> { test_expression!("atan2(NULL, NULL)", "NULL"); test_expression!("atan2(1, NULL)", "NULL"); test_expression!("atan2(NULL, 1)", "NULL"); + test_expression!("nanvl(NULL, NULL)", "NULL"); + test_expression!("nanvl(1, NULL)", "NULL"); + test_expression!("nanvl(NULL, 1)", "NULL"); Ok(()) } diff --git a/datafusion/core/tests/sqllogictests/test_files/math.slt b/datafusion/core/tests/sqllogictests/test_files/math.slt index 152e8b78bdfa..fc27333ec0af 100644 --- a/datafusion/core/tests/sqllogictests/test_files/math.slt +++ b/datafusion/core/tests/sqllogictests/test_files/math.slt @@ -93,3 +93,9 @@ query RRRRRRR SELECT atan2(2.0, 1.0), atan2(-2.0, 1.0), atan2(2.0, -1.0), atan2(-2.0, -1.0), atan2(NULL, 1.0), atan2(2.0, NULL), atan2(NULL, NULL); ---- 1.107148717794 -1.107148717794 2.034443935796 -2.034443935796 NULL NULL NULL + +# nanvl +query RRR +SELECT nanvl(asin(10), 1.0), nanvl(1.0, 2.0), nanvl(asin(10), asin(10)) +---- +1 1 NaN \ No newline at end of file diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt b/datafusion/core/tests/sqllogictests/test_files/scalar.slt index d5ce7737fba0..80f5bd6c9d78 100644 --- a/datafusion/core/tests/sqllogictests/test_files/scalar.slt +++ b/datafusion/core/tests/sqllogictests/test_files/scalar.slt @@ -660,6 +660,41 @@ select round(log2(a), 5), round(log2(b), 5), round(log2(c), 5) from signed_integ NaN 13.28771 NaN NaN 6.64386 NaN +## nanvl + +# nanvl scalar function +query RRR rowsort +select nanvl(0, 1), nanvl(asin(10), 2), nanvl(3, asin(10)); +---- +0 2 3 + +# nanvl scalar nulls +query R rowsort +select nanvl(null, 64); +---- +NULL + +# nanvl scalar nulls #1 +query R rowsort +select nanvl(2, null); +---- +NULL + +# nanvl scalar nulls #2 +query R rowsort +select nanvl(null, null); +---- +NULL + +# nanvl with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(nanvl(asin(f + a), 2), 5), round(nanvl(asin(b + c), 3), 5), round(nanvl(asin(d + e), 4), 5) from small_floats; +---- +0.7754 1.11977 -0.9273 +2 -0.20136 0.7754 +2 -1.11977 4 +NULL NULL NULL + ## pi # pi scalar function diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 33db0f9eb1a4..535ac84457a6 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -89,6 +89,8 @@ pub enum BuiltinScalarFunction { Log10, /// log2 Log2, + /// nanvl + Nanvl, /// pi Pi, /// power @@ -328,6 +330,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Log => Volatility::Immutable, BuiltinScalarFunction::Log10 => Volatility::Immutable, BuiltinScalarFunction::Log2 => Volatility::Immutable, + BuiltinScalarFunction::Nanvl => Volatility::Immutable, BuiltinScalarFunction::Pi => Volatility::Immutable, BuiltinScalarFunction::Power => Volatility::Immutable, BuiltinScalarFunction::Round => Volatility::Immutable, @@ -760,6 +763,11 @@ impl BuiltinScalarFunction { _ => Ok(Float64), }, + BuiltinScalarFunction::Nanvl => match &input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), BuiltinScalarFunction::Abs @@ -1120,6 +1128,10 @@ impl BuiltinScalarFunction { ], self.volatility(), ), + BuiltinScalarFunction::Nanvl => Signature::one_of( + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + self.volatility(), + ), BuiltinScalarFunction::Factorial => { Signature::uniform(1, vec![Int64], self.volatility()) } @@ -1193,6 +1205,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Log => &["log"], BuiltinScalarFunction::Log10 => &["log10"], BuiltinScalarFunction::Log2 => &["log2"], + BuiltinScalarFunction::Nanvl => &["nanvl"], BuiltinScalarFunction::Pi => &["pi"], BuiltinScalarFunction::Power => &["power", "pow"], BuiltinScalarFunction::Radians => &["radians"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 7c769490af29..4a59e92999d5 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -798,6 +798,7 @@ scalar_expr!( scalar_expr!(CurrentDate, current_date, ,"returns current UTC date as a [`DataType::Date32`] value"); scalar_expr!(Now, now, ,"returns current timestamp in nanoseconds, using the same value for all instances of now() in same statement"); scalar_expr!(CurrentTime, current_time, , "returns current UTC time as a [`DataType::Time64`] value"); +scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); @@ -989,6 +990,7 @@ mod test { test_unary_scalar_expr!(Log10, log10); test_unary_scalar_expr!(Ln, ln); test_scalar_expr!(Atan2, atan2, y, x); + test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Ascii, ascii, input); test_scalar_expr!(BitLength, bit_length, string); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c9683d2cdbc9..5e7d7e566e5d 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -380,6 +380,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), + BuiltinScalarFunction::Nanvl => { + Arc::new(|args| make_scalar_function(math_expressions::nanvl)(args)) + } BuiltinScalarFunction::Radians => Arc::new(math_expressions::to_radians), BuiltinScalarFunction::Random => Arc::new(math_expressions::random), BuiltinScalarFunction::Round => { diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 883c016c047b..03e0bb64551b 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -259,6 +259,53 @@ pub fn lcm(args: &[ArrayRef]) -> Result { } } +/// Nanvl SQL function +pub fn nanvl(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => { + let compute_nanvl = |x: f64, y: f64| { + if x.is_nan() { + y + } else { + x + } + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Float64Array, + { compute_nanvl } + )) as ArrayRef) + } + + DataType::Float32 => { + let compute_nanvl = |x: f32, y: f32| { + if x.is_nan() { + y + } else { + x + } + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Float32Array, + { compute_nanvl } + )) as ArrayRef) + } + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function nanvl" + ))), + } +} + /// Pi SQL function pub fn pi(args: &[ColumnarValue]) -> Result { if !matches!(&args[0], ColumnarValue::Array(_)) { @@ -958,4 +1005,40 @@ mod tests { assert_eq!(floats.value(3), 123.0); assert_eq!(floats.value(4), -321.0); } + + #[test] + fn test_nanvl_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // y + Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // x + ]; + + let result = nanvl(&args).expect("failed to initialize function atan2"); + let floats = + as_float64_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 1.0); + assert_eq!(floats.value(1), 6.0); + assert_eq!(floats.value(2), 3.0); + assert!(floats.value(3).is_nan()); + } + + #[test] + fn test_nanvl_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // y + Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // x + ]; + + let result = nanvl(&args).expect("failed to initialize function atan2"); + let floats = + as_float32_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 1.0); + assert_eq!(floats.value(1), 6.0); + assert_eq!(floats.value(2), 3.0); + assert!(floats.value(3).is_nan()); + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9694a5beb7e4..19dace31328d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -577,6 +577,7 @@ enum ScalarFunction { ArrayReplaceN = 108; ArrayRemoveAll = 109; ArrayReplaceAll = 110; + Nanvl = 111; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 40f58b312acd..dbdca6f28251 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -18284,6 +18284,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayReplaceN => "ArrayReplaceN", Self::ArrayRemoveAll => "ArrayRemoveAll", Self::ArrayReplaceAll => "ArrayReplaceAll", + Self::Nanvl => "Nanvl", }; serializer.serialize_str(variant) } @@ -18405,6 +18406,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplaceN", "ArrayRemoveAll", "ArrayReplaceAll", + "Nanvl", ]; struct GeneratedVisitor; @@ -18557,6 +18559,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplaceN" => Ok(ScalarFunction::ArrayReplaceN), "ArrayRemoveAll" => Ok(ScalarFunction::ArrayRemoveAll), "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), + "Nanvl" => Ok(ScalarFunction::Nanvl), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 7e4a5f8afdf8..605bc2033e1c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2310,6 +2310,7 @@ pub enum ScalarFunction { ArrayReplaceN = 108, ArrayRemoveAll = 109, ArrayReplaceAll = 110, + Nanvl = 111, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2428,6 +2429,7 @@ impl ScalarFunction { ScalarFunction::ArrayReplaceN => "ArrayReplaceN", ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll", ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", + ScalarFunction::Nanvl => "Nanvl", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2543,6 +2545,7 @@ impl ScalarFunction { "ArrayReplaceN" => Some(Self::ArrayReplaceN), "ArrayRemoveAll" => Some(Self::ArrayRemoveAll), "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), + "Nanvl" => Some(Self::Nanvl), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 4caff5fba060..86d2a683cfe8 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -46,10 +46,10 @@ use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, now, nullif, octet_length, pi, power, radians, random, - regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, - sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, - strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians, + random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, + rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, + starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, trim, trim_array, trunc, upper, uuid, window_frame::regularize, @@ -522,6 +522,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::StructFun => Self::Struct, ScalarFunction::FromUnixtime => Self::FromUnixtime, ScalarFunction::Atan2 => Self::Atan2, + ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, } } @@ -1527,6 +1528,10 @@ pub fn parse_expr( ScalarFunction::CurrentDate => Ok(current_date()), ScalarFunction::CurrentTime => Ok(current_time()), ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry)?)), + ScalarFunction::Nanvl => Ok(nanvl( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", )), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 3f4fdfeb7486..e90ba317b145 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1474,6 +1474,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Struct => Self::StructFun, BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, BuiltinScalarFunction::Atan2 => Self::Atan2, + BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, }; diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 1e90edc1124d..d8337832d464 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -43,6 +43,7 @@ - [log](#log) - [log10](#log10) - [log2](#log2) +- [nanvl](#nanvl) - [pi](#pi) - [power](#power) - [pow](#pow) @@ -353,6 +354,22 @@ log2(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +### `nanvl` + +Returns the first argument if it's not _NaN_. +Returns the second argument otherwise. + +``` +nanvl(expression_x, expression_y) +``` + +#### Arguments + +- **expression_x**: Numeric expression to return if it's not _NaN_. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Numeric expression to return if the first expression is _NaN_. + Can be a constant, column, or function, and any combination of arithmetic operators. + ### `pi` Returns an approximate value of π.