Skip to content

Commit b8fab5c

Browse files
authored
Replace GetFieldAccess with indexing function in SqlToRel (#10375)
* use func in parser
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* add tests
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* add test
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* rm test1
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* parser done
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* fmt
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* fix exprapi test
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* fix test
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* fix conflicts
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
---------
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
1 parent 18fc376 commit b8fab5c

File tree

5 files changed: +172 / -50 lines changed

datafusion/core/tests/expr_api/mod.rs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,8 @@ fn test_eq_with_coercion() {
6060

6161
#[test]
6262
fn test_get_field() {
63-
// field access Expr::field() requires a rewrite to work
6463
evaluate_expr_test(
65-
col("props").field("a"),
64+
get_field(col("props"), lit("a")),
6665
vec![
6766
"+------------+",
6867
"| expr |",
@@ -77,11 +76,8 @@ fn test_get_field() {
7776

7877
#[test]
7978
fn test_nested_get_field() {
80-
// field access Expr::field() requires a rewrite to work, test when it is
81-
// not the root expression
8279
evaluate_expr_test(
83-
col("props")
84-
.field("a")
80+
get_field(col("props"), lit("a"))
8581
.eq(lit("2021-02-02"))
8682
.or(col("id").eq(lit(1))),
8783
vec![
@@ -98,9 +94,8 @@ fn test_nested_get_field() {
9894

9995
#[test]
10096
fn test_list() {
101-
// list access also requires a rewrite to work
10297
evaluate_expr_test(
103-
col("list").index(lit(1i64)),
98+
array_element(col("list"), lit(1i64)),
10499
vec![
105100
"+------+", "| expr |", "+------+", "| one |", "| two |", "| five |",
106101
"+------+",
@@ -110,9 +105,8 @@ fn test_list() {
110105

111106
#[test]
112107
fn test_list_range() {
113-
// range access also requires a rewrite to work
114108
evaluate_expr_test(
115-
col("list").range(lit(1i64), lit(2i64)),
109+
array_slice(col("list"), lit(1i64), lit(2i64), None),
116110
vec![
117111
"+--------------+",
118112
"| expr |",

datafusion/functions-array/src/rewrite.rs

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,14 @@
1919
2020
use crate::array_has::array_has_all;
2121
use crate::concat::{array_append, array_concat, array_prepend};
22-
use crate::extract::{array_element, array_slice};
2322
use datafusion_common::config::ConfigOptions;
2423
use datafusion_common::tree_node::Transformed;
2524
use datafusion_common::utils::list_ndims;
2625
use datafusion_common::Result;
2726
use datafusion_common::{Column, DFSchema};
2827
use datafusion_expr::expr::ScalarFunction;
2928
use datafusion_expr::expr_rewriter::FunctionRewrite;
30-
use datafusion_expr::{BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Operator};
31-
use datafusion_functions::expr_fn::get_field;
29+
use datafusion_expr::{BinaryExpr, Expr, Operator};
3230

3331
/// Rewrites expressions into function calls to array functions
3432
pub(crate) struct ArrayFunctionRewriter {}
@@ -148,31 +146,6 @@ impl FunctionRewrite for ArrayFunctionRewriter {
148146
Transformed::yes(array_prepend(*left, *right))
149147
}
150148

151-
Expr::GetIndexedField(GetIndexedField {
152-
expr,
153-
field: GetFieldAccess::NamedStructField { name },
154-
}) => {
155-
let name = Expr::Literal(name);
156-
Transformed::yes(get_field(*expr, name))
157-
}
158-
159-
// expr[idx] ==> array_element(expr, idx)
160-
Expr::GetIndexedField(GetIndexedField {
161-
expr,
162-
field: GetFieldAccess::ListIndex { key },
163-
}) => Transformed::yes(array_element(*expr, *key)),
164-
165-
// expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
166-
Expr::GetIndexedField(GetIndexedField {
167-
expr,
168-
field:
169-
GetFieldAccess::ListRange {
170-
start,
171-
stop,
172-
stride,
173-
},
174-
}) => Transformed::yes(array_slice(*expr, *start, *stop, Some(*stride))),
175-
176149
_ => Transformed::no(expr),
177150
};
178151
Ok(transformed)

datafusion/sql/src/expr/identifier.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
1919
use arrow_schema::Field;
2020
use datafusion_common::{
2121
internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result,
22-
TableReference,
22+
ScalarValue, TableReference,
2323
};
24-
use datafusion_expr::{Case, Expr};
24+
use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr};
2525
use sqlparser::ast::{Expr as SQLExpr, Ident};
2626

2727
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -133,7 +133,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
133133
);
134134
}
135135
let nested_name = nested_names[0].to_string();
136-
Ok(Expr::Column(Column::from((qualifier, field))).field(nested_name))
136+
137+
let col = Expr::Column(Column::from((qualifier, field)));
138+
if let Some(udf) =
139+
self.context_provider.get_function_meta("get_field")
140+
{
141+
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
142+
udf,
143+
vec![col, lit(ScalarValue::from(nested_name))],
144+
)))
145+
} else {
146+
internal_err!("get_field not found")
147+
}
137148
}
138149
// found matching field with no spare identifier(s)
139150
Some((field, qualifier, _nested_names)) => {

datafusion/sql/src/expr/mod.rs

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use datafusion_expr::expr::InList;
2929
use datafusion_expr::expr::ScalarFunction;
3030
use datafusion_expr::{
3131
col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable,
32-
GetFieldAccess, GetIndexedField, Like, Literal, Operator, TryCast,
32+
GetFieldAccess, Like, Literal, Operator, TryCast,
3333
};
3434

3535
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
@@ -1019,10 +1019,48 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
10191019
expr
10201020
};
10211021

1022-
Ok(Expr::GetIndexedField(GetIndexedField::new(
1023-
Box::new(expr),
1024-
self.plan_indices(indices, schema, planner_context)?,
1025-
)))
1022+
let field = self.plan_indices(indices, schema, planner_context)?;
1023+
match field {
1024+
GetFieldAccess::NamedStructField { name } => {
1025+
if let Some(udf) = self.context_provider.get_function_meta("get_field") {
1026+
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
1027+
udf,
1028+
vec![expr, lit(name)],
1029+
)))
1030+
} else {
1031+
internal_err!("get_field not found")
1032+
}
1033+
}
1034+
// expr[idx] ==> array_element(expr, idx)
1035+
GetFieldAccess::ListIndex { key } => {
1036+
if let Some(udf) =
1037+
self.context_provider.get_function_meta("array_element")
1038+
{
1039+
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
1040+
udf,
1041+
vec![expr, *key],
1042+
)))
1043+
} else {
1044+
internal_err!("get_field not found")
1045+
}
1046+
}
1047+
// expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
1048+
GetFieldAccess::ListRange {
1049+
start,
1050+
stop,
1051+
stride,
1052+
} => {
1053+
if let Some(udf) = self.context_provider.get_function_meta("array_slice")
1054+
{
1055+
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
1056+
udf,
1057+
vec![expr, *start, *stop, *stride],
1058+
)))
1059+
} else {
1060+
internal_err!("array_slice not found")
1061+
}
1062+
}
1063+
}
10261064
}
10271065
}
10281066

datafusion/sqllogictest/test_files/expr.slt

Lines changed: 110 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2324,28 +2324,134 @@ host3 3.3
23242324

23252325
# can have an aggregate function with an inner CASE WHEN
23262326
query TR
2327-
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
2327+
select
2328+
t2.server_host as host,
2329+
sum((
2330+
case when t2.server_host is not null
2331+
then t2.server_load2
2332+
end
2333+
))
2334+
from (
2335+
select
2336+
struct(time,load1,load2,host)['c2'] as server_load2,
2337+
struct(time,load1,load2,host)['c3'] as server_host
2338+
from t1
2339+
) t2
2340+
where server_host IS NOT NULL
2341+
group by server_host order by host;
23282342
----
23292343
host1 101
23302344
host2 202
23312345
host3 303
23322346

2347+
# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
2348+
query error
2349+
select
2350+
t2.server['c3'] as host,
2351+
sum((
2352+
case when t2.server['c3'] is not null
2353+
then t2.server['c2']
2354+
end
2355+
))
2356+
from (
2357+
select
2358+
struct(time,load1,load2,host) as server
2359+
from t1
2360+
) t2
2361+
where t2.server['c3'] IS NOT NULL
2362+
group by t2.server['c3'] order by host;
2363+
23332364
# can have 2 projections with aggr(short_circuited), with different short-circuited expr
23342365
query TRR
2335-
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
2366+
select
2367+
t2.server_host as host,
2368+
sum(coalesce(server_load1)),
2369+
sum((
2370+
case when t2.server_host is not null
2371+
then t2.server_load2
2372+
end
2373+
))
2374+
from (
2375+
select
2376+
struct(time,load1,load2,host)['c1'] as server_load1,
2377+
struct(time,load1,load2,host)['c2'] as server_load2,
2378+
struct(time,load1,load2,host)['c3'] as server_host
2379+
from t1
2380+
) t2
2381+
where server_host IS NOT NULL
2382+
group by server_host order by host;
23362383
----
23372384
host1 1.1 101
23382385
host2 2.2 202
23392386
host3 3.3 303
23402387

2341-
# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
2388+
# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
2389+
query error
2390+
select
2391+
t2.server['c3'] as host,
2392+
sum(coalesce(server['c1'])),
2393+
sum((
2394+
case when t2.server['c3'] is not null
2395+
then t2.server['c2']
2396+
end
2397+
))
2398+
from (
2399+
select
2400+
struct(time,load1,load2,host) as server,
2401+
from t1
2402+
) t2
2403+
where server_host IS NOT NULL
2404+
group by server_host order by host;
2405+
23422406
query TRR
2343-
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
2407+
select
2408+
t2.server_host as host,
2409+
sum((
2410+
case when t2.server_host is not null
2411+
then server_load1
2412+
end
2413+
)),
2414+
sum((
2415+
case when server_host is not null
2416+
then server_load2
2417+
end
2418+
))
2419+
from (
2420+
select
2421+
struct(time,load1,load2,host)['c1'] as server_load1,
2422+
struct(time,load1,load2,host)['c2'] as server_load2,
2423+
struct(time,load1,load2,host)['c3'] as server_host
2424+
from t1
2425+
) t2
2426+
where server_host IS NOT NULL
2427+
group by server_host order by host;
23442428
----
23452429
host1 1.1 101
23462430
host2 2.2 202
23472431
host3 3.3 303
23482432

2433+
# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364
2434+
query error
2435+
select
2436+
t2.server['c3'] as host,
2437+
sum((
2438+
case when t2.server['c3'] is not null
2439+
then t2.server['c1']
2440+
end
2441+
)),
2442+
sum((
2443+
case when t2.server['c3'] is not null
2444+
then t2.server['c2']
2445+
end
2446+
))
2447+
from (
2448+
select
2449+
struct(time,load1,load2,host) as server
2450+
from t1
2451+
) t2
2452+
where t2.server['c3'] IS NOT NULL
2453+
group by t2.server['c3'] order by host;
2454+
23492455
# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce)
23502456
query TRR
23512457
select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;

0 commit comments

Comments
 (0)