Skip to content

Commit 6ddce2e

Browse files
committed
Use mark join in decorrelate subqueries
This fixes a correctness issue in the current approach.
1 parent 9213260 commit 6ddce2e

File tree

3 files changed

+65
-82
lines changed

3 files changed

+65
-82
lines changed

datafusion/optimizer/src/decorrelate_predicate_subquery.rs

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
//! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins
1919
use std::collections::BTreeSet;
20-
use std::iter;
2120
use std::ops::Deref;
2221
use std::sync::Arc;
2322

@@ -34,11 +33,10 @@ use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
3433
use datafusion_expr::logical_plan::{JoinType, Subquery};
3534
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
3635
use datafusion_expr::{
37-
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
36+
exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
3837
LogicalPlan, LogicalPlanBuilder, Operator,
3938
};
4039

41-
use itertools::chain;
4240
use log::debug;
4341

4442
/// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins
@@ -138,17 +136,14 @@ fn rewrite_inner_subqueries(
138136
Expr::Exists(Exists {
139137
subquery: Subquery { subquery, .. },
140138
negated,
141-
}) => {
142-
match existence_join(&cur_input, Arc::clone(&subquery), None, negated, alias)?
143-
{
144-
Some((plan, exists_expr)) => {
145-
cur_input = plan;
146-
Ok(Transformed::yes(exists_expr))
147-
}
148-
None if negated => Ok(Transformed::no(not_exists(subquery))),
149-
None => Ok(Transformed::no(exists(subquery))),
139+
}) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? {
140+
Some((plan, exists_expr)) => {
141+
cur_input = plan;
142+
Ok(Transformed::yes(exists_expr))
150143
}
151-
}
144+
None if negated => Ok(Transformed::no(not_exists(subquery))),
145+
None => Ok(Transformed::no(exists(subquery))),
146+
},
152147
Expr::InSubquery(InSubquery {
153148
expr,
154149
subquery: Subquery { subquery, .. },
@@ -159,7 +154,7 @@ fn rewrite_inner_subqueries(
159154
.map_or(plan_err!("single expression required."), |output_expr| {
160155
Ok(Expr::eq(*expr.clone(), output_expr))
161156
})?;
162-
match existence_join(
157+
match mark_join(
163158
&cur_input,
164159
Arc::clone(&subquery),
165160
Some(in_predicate),
@@ -283,10 +278,6 @@ fn build_join_top(
283278
build_join(left, subquery, in_predicate_opt, join_type, subquery_alias)
284279
}
285280

286-
/// Existence join is emulated by adding a non-nullable column to the subquery and using a left join
287-
/// and checking if the column is null or not. If native support is added for Existence/Mark then
288-
/// we should use that instead.
289-
///
290281
/// This is used to handle the case when the subquery is embedded in a more complex boolean
291282
/// expression like an OR. For example
292283
///
@@ -296,37 +287,26 @@ fn build_join_top(
296287
///
297288
/// ```text
298289
/// Projection: t1.id
299-
/// Filter: t1.id < 0 OR __correlated_sq_1.__exists IS NOT NULL
290+
/// Filter: t1.id < 0 OR __correlated_sq_1.mark
300291
/// LeftMark Join: Filter: t1.id = __correlated_sq_1.id
301292
/// TableScan: t1
302293
/// SubqueryAlias: __correlated_sq_1
303-
/// Projection: t2.id, true as __exists
294+
/// Projection: t2.id
304295
/// TableScan: t2
305-
fn existence_join(
296+
fn mark_join(
306297
left: &LogicalPlan,
307298
subquery: Arc<LogicalPlan>,
308299
in_predicate_opt: Option<Expr>,
309300
negated: bool,
310301
alias_generator: &Arc<AliasGenerator>,
311302
) -> Result<Option<(LogicalPlan, Expr)>> {
312-
// Add non nullable column to emulate existence join
313-
let always_true_expr = lit(true).alias("__exists");
314-
let cols = chain(
315-
subquery.schema().columns().into_iter().map(Expr::Column),
316-
iter::once(always_true_expr),
317-
);
318-
let subquery = LogicalPlanBuilder::from(subquery).project(cols)?.build()?;
319303
let alias = alias_generator.next("__correlated_sq");
320304

321-
let exists_col = Expr::Column(Column::new(Some(alias.clone()), "__exists"));
322-
let exists_expr = if negated {
323-
exists_col.is_null()
324-
} else {
325-
exists_col.is_not_null()
326-
};
305+
let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark"));
306+
let exists_expr = if negated { !exists_col } else { exists_col };
327307

328308
Ok(
329-
build_join(left, &subquery, in_predicate_opt, JoinType::Left, alias)?
309+
build_join(left, &subquery, in_predicate_opt, JoinType::LeftMark, alias)?
330310
.map(|plan| (plan, exists_expr)),
331311
)
332312
}

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,13 +1056,11 @@ where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)
10561056
----
10571057
logical_plan
10581058
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
1059-
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL
1060-
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
1061-
04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0)
1062-
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1063-
06)--------SubqueryAlias: __correlated_sq_1
1064-
07)----------Projection: t2.t2_id, Boolean(true) AS __exists
1065-
08)------------TableScan: t2 projection=[t2_id]
1059+
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark
1060+
03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0)
1061+
04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1062+
05)------SubqueryAlias: __correlated_sq_1
1063+
06)--------TableScan: t2 projection=[t2_id]
10661064

10671065
query ITI rowsort
10681066
select t1.t1_id,
@@ -1085,13 +1083,12 @@ where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t
10851083
----
10861084
logical_plan
10871085
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
1088-
02)--Filter: t1.t1_id = Int32(11) OR __correlated_sq_1.__exists IS NULL
1089-
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
1090-
04)------Left Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0)
1091-
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1092-
06)--------SubqueryAlias: __correlated_sq_1
1093-
07)----------Projection: CAST(t2.t2_id AS Int64) + Int64(1), Boolean(true) AS __exists
1094-
08)------------TableScan: t2 projection=[t2_id]
1086+
02)--Filter: t1.t1_id = Int32(11) OR NOT __correlated_sq_1.mark
1087+
03)----LeftMark Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0)
1088+
04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1089+
05)------SubqueryAlias: __correlated_sq_1
1090+
06)--------Projection: CAST(t2.t2_id AS Int64) + Int64(1)
1091+
07)----------TableScan: t2 projection=[t2_id]
10951092

10961093
query ITI rowsort
10971094
select t1.t1_id,
@@ -1113,13 +1110,11 @@ where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id)
11131110
----
11141111
logical_plan
11151112
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
1116-
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL
1117-
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
1118-
04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id
1119-
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1120-
06)--------SubqueryAlias: __correlated_sq_1
1121-
07)----------Projection: t2.t2_id, Boolean(true) AS __exists
1122-
08)------------TableScan: t2 projection=[t2_id]
1113+
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark
1114+
03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id
1115+
04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1116+
05)------SubqueryAlias: __correlated_sq_1
1117+
06)--------TableScan: t2 projection=[t2_id]
11231118

11241119
query ITI rowsort
11251120
select t1.t1_id,
@@ -1142,13 +1137,11 @@ where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id)
11421137
----
11431138
logical_plan
11441139
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
1145-
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NULL
1146-
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
1147-
04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id
1148-
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1149-
06)--------SubqueryAlias: __correlated_sq_1
1150-
07)----------Projection: t2.t2_id, Boolean(true) AS __exists
1151-
08)------------TableScan: t2 projection=[t2_id]
1140+
02)--Filter: t1.t1_id > Int32(40) OR NOT __correlated_sq_1.mark
1141+
03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id
1142+
04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1143+
05)------SubqueryAlias: __correlated_sq_1
1144+
06)--------TableScan: t2 projection=[t2_id]
11521145

11531146
query ITI rowsort
11541147
select t1.t1_id,
@@ -1170,16 +1163,14 @@ where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (s
11701163
----
11711164
logical_plan
11721165
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
1173-
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.__exists IS NOT NULL
1174-
03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_2.__exists
1175-
04)------Left Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0)
1176-
05)--------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id
1177-
06)----------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1178-
07)----------SubqueryAlias: __correlated_sq_1
1179-
08)------------TableScan: t3 projection=[t3_id]
1180-
09)--------SubqueryAlias: __correlated_sq_2
1181-
10)----------Projection: t2.t2_id, Boolean(true) AS __exists
1182-
11)------------TableScan: t2 projection=[t2_id]
1166+
02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.mark
1167+
03)----LeftMark Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0)
1168+
04)------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id
1169+
05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
1170+
06)--------SubqueryAlias: __correlated_sq_1
1171+
07)----------TableScan: t3 projection=[t3_id]
1172+
08)------SubqueryAlias: __correlated_sq_2
1173+
09)--------TableScan: t2 projection=[t2_id]
11831174

11841175
query ITI rowsort
11851176
select t1.t1_id,
@@ -1192,6 +1183,18 @@ where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (s
11921183
22 b 2
11931184
44 d 4
11941185

1186+
# Handle duplicate values in exists query
1187+
query ITI rowsort
1188+
select t1.t1_id,
1189+
t1.t1_name,
1190+
t1.t1_int
1191+
from t1
1192+
where t1.t1_id > 40 or exists (select * from t2 cross join t3 where t1.t1_id = t2.t2_id)
1193+
----
1194+
11 a 1
1195+
22 b 2
1196+
44 d 4
1197+
11951198
# Nested subqueries
11961199
query ITI rowsort
11971200
select t1.t1_id,

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -473,15 +473,15 @@ async fn roundtrip_inlist_5() -> Result<()> {
473473
// on roundtrip there is an additional projection during TableScan which includes all columns of the table,
474474
// using assert_expected_plan here as a workaround
475475
assert_expected_plan(
476-
"SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))",
477-
"Projection: data.a, data.f\
478-
\n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR Boolean(true) IS NOT NULL\
479-
\n Projection: data.a, data.f, Boolean(true)\
480-
\n Left Join: data.a = data2.a\
481-
\n TableScan: data projection=[a, f]\
482-
\n Projection: data2.a, Boolean(true)\
483-
\n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\
484-
\n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]",
476+
"SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))",
477+
478+
"Projection: data.a, data.f\
479+
\n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data2.mark\
480+
\n LeftMark Join: data.a = data2.a\
481+
\n TableScan: data projection=[a, f]\
482+
\n Projection: data2.a\
483+
\n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\
484+
\n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]",
485485
true).await
486486
}
487487

0 commit comments

Comments
 (0)