Skip to content

Commit f33a843

Browse files
committed
Detect when filters make subqueries scalar
1 parent 9619f02 commit f33a843

File tree

3 files changed

+188
-3
lines changed

3 files changed

+188
-3
lines changed

datafusion/common/src/functional_dependencies.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,14 @@ impl FunctionalDependencies {
413413
}
414414
}
415415

416+
impl Deref for FunctionalDependencies {
417+
type Target = [FunctionalDependence];
418+
419+
fn deref(&self) -> &Self::Target {
420+
self.deps.as_slice()
421+
}
422+
}
423+
416424
/// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression.
417425
pub fn aggregate_functional_dependencies(
418426
aggr_input_schema: &DFSchema,

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 162 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ use datafusion_common::tree_node::{
4747
};
4848
use datafusion_common::{
4949
aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
50-
DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
50+
DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies,
5151
OwnedTableReference, Result, ScalarValue, UnnestOptions,
5252
};
5353
// backwards compatibility
@@ -1030,7 +1030,13 @@ impl LogicalPlan {
10301030
pub fn max_rows(self: &LogicalPlan) -> Option<usize> {
10311031
match self {
10321032
LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(),
1033-
LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(),
1033+
LogicalPlan::Filter(filter) => {
1034+
if filter.is_scalar() {
1035+
Some(1)
1036+
} else {
1037+
filter.input.max_rows()
1038+
}
1039+
}
10341040
LogicalPlan::Window(Window { input, .. }) => input.max_rows(),
10351041
LogicalPlan::Aggregate(Aggregate {
10361042
input, group_expr, ..
@@ -1917,6 +1923,101 @@ impl Filter {
19171923

19181924
Ok(Self { predicate, input })
19191925
}
1926+
1927+
/// Is this filter guaranteed to return 0 or 1 row in a given instantiation?
1928+
///
1929+
/// This function will return `true` if its predicate contains a conjunction of
1930+
/// `col(a) = <expr>`, where its schema has a unique filter that is covered
1931+
/// by this conjunction.
1932+
///
1933+
/// For example, for the table:
1934+
/// ```sql
1935+
/// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER);
1936+
/// ```
1937+
/// `Filter(a = 2).is_scalar() == true`
1938+
/// , whereas
1939+
/// `Filter(b = 2).is_scalar() == false`
1940+
/// and
1941+
/// `Filter(a = 2 OR b = 2).is_scalar() == false`
1942+
fn is_scalar(&self) -> bool {
1943+
let schema = self.input.schema();
1944+
1945+
let functional_dependencies = self.input.schema().functional_dependencies();
1946+
let unique_keys = functional_dependencies.iter().filter(|dep| {
1947+
let nullable = dep.nullable
1948+
&& dep
1949+
.source_indices
1950+
.iter()
1951+
.any(|&source| schema.field(source).is_nullable());
1952+
!nullable
1953+
&& dep.mode == Dependency::Single
1954+
&& dep.target_indices.len() == schema.fields().len()
1955+
});
1956+
1957+
/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1958+
/// TODO: move this function from datafusion-optimizer::utils to datafusion-expr::utils
1959+
/// so we can re-use it
1960+
fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
1961+
split_conjunction_impl(expr, vec![])
1962+
}
1963+
1964+
fn split_conjunction_impl<'a>(
1965+
expr: &'a Expr,
1966+
mut exprs: Vec<&'a Expr>,
1967+
) -> Vec<&'a Expr> {
1968+
match expr {
1969+
Expr::BinaryExpr(BinaryExpr {
1970+
right,
1971+
op: Operator::And,
1972+
left,
1973+
}) => {
1974+
let exprs = split_conjunction_impl(left, exprs);
1975+
split_conjunction_impl(right, exprs)
1976+
}
1977+
Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
1978+
other => {
1979+
exprs.push(other);
1980+
exprs
1981+
}
1982+
}
1983+
}
1984+
1985+
let exprs = split_conjunction(&self.predicate);
1986+
let eq_pred_cols: HashSet<_> = exprs
1987+
.iter()
1988+
.filter_map(|expr| {
1989+
let Expr::BinaryExpr(BinaryExpr {
1990+
left,
1991+
op: Operator::Eq,
1992+
right,
1993+
}) = expr
1994+
else {
1995+
return None;
1996+
};
1997+
// This is a no-op filter expression
1998+
if left == right {
1999+
return None;
2000+
}
2001+
2002+
match (left.as_ref(), right.as_ref()) {
2003+
(Expr::Column(_), Expr::Column(_)) => None,
2004+
(Expr::Column(c), _) | (_, Expr::Column(c)) => {
2005+
Some(schema.index_of_column(c).unwrap())
2006+
}
2007+
_ => None,
2008+
}
2009+
})
2010+
.collect();
2011+
2012+
// If we have a functional dependence that is a subset of our predicate,
2013+
// this filter is scalar
2014+
for key in unique_keys {
2015+
if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) {
2016+
return true;
2017+
}
2018+
}
2019+
false
2020+
}
19202021
}
19212022

19222023
/// Window its input based on a set of window spec and window function (e.g. SUM or RANK)
@@ -2552,12 +2653,14 @@ pub struct Unnest {
25522653
#[cfg(test)]
25532654
mod tests {
25542655
use super::*;
2656+
use crate::builder::LogicalTableSource;
25552657
use crate::logical_plan::table_scan;
25562658
use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
25572659
use arrow::datatypes::{DataType, Field, Schema};
25582660
use datafusion_common::tree_node::TreeNodeVisitor;
2559-
use datafusion_common::{not_impl_err, DFSchema, TableReference};
2661+
use datafusion_common::{not_impl_err, Constraint, DFSchema, TableReference};
25602662
use std::collections::HashMap;
2663+
use std::sync::Arc;
25612664

25622665
fn employee_schema() -> Schema {
25632666
Schema::new(vec![
@@ -3052,4 +3155,60 @@ digraph {
30523155
.unwrap()
30533156
.is_nullable());
30543157
}
3158+
#[test]
3159+
fn test_filter_is_scalar() {
3160+
// test empty placeholder
3161+
let schema =
3162+
Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
3163+
3164+
let source = Arc::new(LogicalTableSource::new(schema));
3165+
let schema = Arc::new(
3166+
DFSchema::try_from_qualified_schema(
3167+
TableReference::bare("tab"),
3168+
&source.schema(),
3169+
)
3170+
.unwrap(),
3171+
);
3172+
let scan = Arc::new(LogicalPlan::TableScan(TableScan {
3173+
table_name: TableReference::bare("tab"),
3174+
source: source.clone(),
3175+
projection: None,
3176+
projected_schema: schema.clone(),
3177+
filters: vec![],
3178+
fetch: None,
3179+
}));
3180+
let col = schema.field(0).qualified_column();
3181+
3182+
let filter = Filter::try_new(
3183+
Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
3184+
scan,
3185+
)
3186+
.unwrap();
3187+
assert!(!filter.is_scalar());
3188+
let unique_schema =
3189+
Arc::new(schema.as_ref().clone().with_functional_dependencies(
3190+
FunctionalDependencies::new_from_constraints(
3191+
Some(&Constraints::new_unverified(vec![Constraint::Unique(
3192+
vec![0],
3193+
)])),
3194+
1,
3195+
),
3196+
));
3197+
let scan = Arc::new(LogicalPlan::TableScan(TableScan {
3198+
table_name: TableReference::bare("tab"),
3199+
source,
3200+
projection: None,
3201+
projected_schema: unique_schema.clone(),
3202+
filters: vec![],
3203+
fetch: None,
3204+
}));
3205+
let col = schema.field(0).qualified_column();
3206+
3207+
let filter = Filter::try_new(
3208+
Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
3209+
scan,
3210+
)
3211+
.unwrap();
3212+
assert!(filter.is_scalar());
3213+
}
30553214
}

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
4949
(44, 'x', 3),
5050
(55, 'w', 3);
5151

52+
statement ok
53+
CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES
54+
(11, 'e', 3),
55+
(22, 'f', 1),
56+
(44, 'g', 3),
57+
(55, 'h', 3);
58+
5259
statement ok
5360
CREATE EXTERNAL TABLE IF NOT EXISTS customer (
5461
c_custkey BIGINT,
@@ -419,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2
419426
statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
420427
SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1
421428

429+
#non_aggregated_correlated_scalar_subquery_unique
430+
query II rowsort
431+
SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1
432+
----
433+
11 3
434+
22 1
435+
33 NULL
436+
44 3
437+
438+
439+
#non_aggregated_correlated_scalar_subquery
422440
statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
423441
SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1
424442

0 commit comments

Comments
 (0)