Skip to content

Commit 40a2055

Browse files
authored
Fix coalesce expr_fn function to take multiple arguments (#10321)
1 parent 89443bf commit 40a2055

File tree

2 files changed

+244
-12
lines changed

2 files changed

+244
-12
lines changed

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion::error::Result;
3030
use datafusion::prelude::*;
3131

3232
use datafusion::assert_batches_eq;
33-
use datafusion_common::DFSchema;
33+
use datafusion_common::{DFSchema, ScalarValue};
3434
use datafusion_expr::expr::Alias;
3535
use datafusion_expr::ExprSchemable;
3636

@@ -161,6 +161,181 @@ async fn test_fn_btrim_with_chars() -> Result<()> {
161161
Ok(())
162162
}
163163

164+
#[tokio::test]
165+
async fn test_fn_nullif() -> Result<()> {
166+
let expr = nullif(col("a"), lit("abcDEF"));
167+
168+
let expected = [
169+
"+-------------------------------+",
170+
"| nullif(test.a,Utf8(\"abcDEF\")) |",
171+
"+-------------------------------+",
172+
"| |",
173+
"| abc123 |",
174+
"| CBAdef |",
175+
"| 123AbcDef |",
176+
"+-------------------------------+",
177+
];
178+
179+
assert_fn_batches!(expr, expected);
180+
181+
Ok(())
182+
}
183+
184+
#[tokio::test]
185+
async fn test_fn_arrow_cast() -> Result<()> {
186+
let expr = arrow_typeof(arrow_cast(col("b"), lit("Float64")));
187+
188+
let expected = [
189+
"+--------------------------------------------------+",
190+
"| arrow_typeof(arrow_cast(test.b,Utf8(\"Float64\"))) |",
191+
"+--------------------------------------------------+",
192+
"| Float64 |",
193+
"| Float64 |",
194+
"| Float64 |",
195+
"| Float64 |",
196+
"+--------------------------------------------------+",
197+
];
198+
199+
assert_fn_batches!(expr, expected);
200+
201+
Ok(())
202+
}
203+
204+
#[tokio::test]
205+
async fn test_nvl() -> Result<()> {
206+
let lit_null = lit(ScalarValue::Utf8(None));
207+
// nvl(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'TURNED_NULL')
208+
let expr = nvl(
209+
when(col("a").eq(lit("abcDEF")), lit_null)
210+
.otherwise(col("a"))
211+
.unwrap(),
212+
lit("TURNED_NULL"),
213+
)
214+
.alias("nvl_expr");
215+
216+
let expected = [
217+
"+-------------+",
218+
"| nvl_expr |",
219+
"+-------------+",
220+
"| TURNED_NULL |",
221+
"| abc123 |",
222+
"| CBAdef |",
223+
"| 123AbcDef |",
224+
"+-------------+",
225+
];
226+
227+
assert_fn_batches!(expr, expected);
228+
229+
Ok(())
230+
}
231+
#[tokio::test]
232+
async fn test_nvl2() -> Result<()> {
233+
let lit_null = lit(ScalarValue::Utf8(None));
234+
// nvl2(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'NON_NUll', 'TURNED_NULL')
235+
let expr = nvl2(
236+
when(col("a").eq(lit("abcDEF")), lit_null)
237+
.otherwise(col("a"))
238+
.unwrap(),
239+
lit("NON_NULL"),
240+
lit("TURNED_NULL"),
241+
)
242+
.alias("nvl2_expr");
243+
244+
let expected = [
245+
"+-------------+",
246+
"| nvl2_expr |",
247+
"+-------------+",
248+
"| TURNED_NULL |",
249+
"| NON_NULL |",
250+
"| NON_NULL |",
251+
"| NON_NULL |",
252+
"+-------------+",
253+
];
254+
255+
assert_fn_batches!(expr, expected);
256+
257+
Ok(())
258+
}
259+
#[tokio::test]
260+
async fn test_fn_arrow_typeof() -> Result<()> {
261+
let expr = arrow_typeof(col("l"));
262+
263+
let expected = [
264+
"+------------------------------------------------------------------------------------------------------------------+",
265+
"| arrow_typeof(test.l) |",
266+
"+------------------------------------------------------------------------------------------------------------------+",
267+
"| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
268+
"| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
269+
"| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
270+
"| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |",
271+
"+------------------------------------------------------------------------------------------------------------------+",
272+
];
273+
274+
assert_fn_batches!(expr, expected);
275+
276+
Ok(())
277+
}
278+
279+
#[tokio::test]
280+
async fn test_fn_struct() -> Result<()> {
281+
let expr = r#struct(vec![col("a"), col("b")]);
282+
283+
let expected = [
284+
"+--------------------------+",
285+
"| struct(test.a,test.b) |",
286+
"+--------------------------+",
287+
"| {c0: abcDEF, c1: 1} |",
288+
"| {c0: abc123, c1: 10} |",
289+
"| {c0: CBAdef, c1: 10} |",
290+
"| {c0: 123AbcDef, c1: 100} |",
291+
"+--------------------------+",
292+
];
293+
294+
assert_fn_batches!(expr, expected);
295+
296+
Ok(())
297+
}
298+
299+
#[tokio::test]
300+
async fn test_fn_named_struct() -> Result<()> {
301+
let expr = named_struct(vec![lit("column_a"), col("a"), lit("column_b"), col("b")]);
302+
303+
let expected = [
304+
"+---------------------------------------------------------------+",
305+
"| named_struct(Utf8(\"column_a\"),test.a,Utf8(\"column_b\"),test.b) |",
306+
"+---------------------------------------------------------------+",
307+
"| {column_a: abcDEF, column_b: 1} |",
308+
"| {column_a: abc123, column_b: 10} |",
309+
"| {column_a: CBAdef, column_b: 10} |",
310+
"| {column_a: 123AbcDef, column_b: 100} |",
311+
"+---------------------------------------------------------------+",
312+
];
313+
314+
assert_fn_batches!(expr, expected);
315+
316+
Ok(())
317+
}
318+
319+
#[tokio::test]
320+
async fn test_fn_coalesce() -> Result<()> {
321+
let expr = coalesce(vec![lit(ScalarValue::Utf8(None)), lit("ab")]);
322+
323+
let expected = [
324+
"+---------------------------------+",
325+
"| coalesce(Utf8(NULL),Utf8(\"ab\")) |",
326+
"+---------------------------------+",
327+
"| ab |",
328+
"| ab |",
329+
"| ab |",
330+
"| ab |",
331+
"+---------------------------------+",
332+
];
333+
334+
assert_fn_batches!(expr, expected);
335+
336+
Ok(())
337+
}
338+
164339
#[tokio::test]
165340
async fn test_fn_approx_median() -> Result<()> {
166341
let expr = approx_median(col("b"));

datafusion/functions/src/core/mod.rs

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
//! "core" DataFusion functions
1919
20+
use datafusion_expr::ScalarUDF;
21+
use std::sync::Arc;
22+
2023
pub mod arrow_cast;
2124
pub mod arrowtypeof;
2225
pub mod coalesce;
@@ -39,14 +42,68 @@ make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);
3942
make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce);
4043

4144
// Export the functions out of this package, both as expr_fn as well as a list of functions
42-
export_functions!(
43-
(nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."),
44-
(arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. This can be used to cast to a specific `arrow_type`."),
45-
(nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"),
46-
(nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."),
47-
(arrow_typeof, arg_1, "Returns the Arrow type of the input expression."),
48-
(r#struct, args, "Returns a struct with the given arguments"),
49-
(named_struct, args, "Returns a struct with the given names and arguments pairs"),
50-
(get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct"),
51-
(coalesce, args, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL")
52-
);
45+
pub mod expr_fn {
46+
use datafusion_expr::Expr;
47+
48+
/// returns NULL if value1 equals value2; otherwise it returns value1. This
49+
/// can be used to perform the inverse operation of the COALESCE expression
50+
pub fn nullif(arg1: Expr, arg2: Expr) -> Expr {
51+
super::nullif().call(vec![arg1, arg2])
52+
}
53+
54+
/// returns value1 cast to the `arrow_type` given the second argument. This
55+
/// can be used to cast to a specific `arrow_type`.
56+
pub fn arrow_cast(arg1: Expr, arg2: Expr) -> Expr {
57+
super::arrow_cast().call(vec![arg1, arg2])
58+
}
59+
60+
/// Returns value2 if value1 is NULL; otherwise it returns value1
61+
pub fn nvl(arg1: Expr, arg2: Expr) -> Expr {
62+
super::nvl().call(vec![arg1, arg2])
63+
}
64+
65+
/// Returns value2 if value1 is not NULL; otherwise, it returns value3.
66+
pub fn nvl2(arg1: Expr, arg2: Expr, arg3: Expr) -> Expr {
67+
super::nvl2().call(vec![arg1, arg2, arg3])
68+
}
69+
70+
/// Returns the Arrow type of the input expression.
71+
pub fn arrow_typeof(arg1: Expr) -> Expr {
72+
super::arrow_typeof().call(vec![arg1])
73+
}
74+
75+
/// Returns a struct with the given arguments
76+
pub fn r#struct(args: Vec<Expr>) -> Expr {
77+
super::r#struct().call(args)
78+
}
79+
80+
/// Returns a struct with the given names and arguments pairs
81+
pub fn named_struct(args: Vec<Expr>) -> Expr {
82+
super::named_struct().call(args)
83+
}
84+
85+
/// Returns the value of the field with the given name from the struct
86+
pub fn get_field(arg1: Expr, arg2: Expr) -> Expr {
87+
super::get_field().call(vec![arg1, arg2])
88+
}
89+
90+
/// Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL
91+
pub fn coalesce(args: Vec<Expr>) -> Expr {
92+
super::coalesce().call(args)
93+
}
94+
}
95+
96+
/// Return a list of all functions in this package
97+
pub fn functions() -> Vec<Arc<ScalarUDF>> {
98+
vec![
99+
nullif(),
100+
arrow_cast(),
101+
nvl(),
102+
nvl2(),
103+
arrow_typeof(),
104+
r#struct(),
105+
named_struct(),
106+
get_field(),
107+
coalesce(),
108+
]
109+
}

0 commit comments

Comments
 (0)