Commit 52fadb7

Detect when filters make subqueries scalar
1 parent 9619f02 commit 52fadb7

3 files changed: +160 -3 lines changed


datafusion/common/src/functional_dependencies.rs

Lines changed: 8 additions & 0 deletions
@@ -413,6 +413,14 @@ impl FunctionalDependencies {
     }
 }
 
+impl Deref for FunctionalDependencies {
+    type Target = [FunctionalDependence];
+
+    fn deref(&self) -> &Self::Target {
+        self.deps.as_slice()
+    }
+}
+
 /// Calculates functional dependencies for aggregate output, when there is a GROUP BY expression.
 pub fn aggregate_functional_dependencies(
     aggr_input_schema: &DFSchema,
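
For illustration only (not part of this commit): the Deref impl added above is what lets callers treat FunctionalDependencies as a slice of FunctionalDependence, which is how Filter::is_scalar in the next file iterates it. A minimal hedged sketch, assuming only the items this diff itself imports from datafusion_common:

    use datafusion_common::{Dependency, FunctionalDependencies};

    // Iterate the dependencies as a slice (via the Deref added above) and
    // count those whose determinant is a single key.
    fn single_key_count(fds: &FunctionalDependencies) -> usize {
        fds.iter().filter(|dep| dep.mode == Dependency::Single).count()
    }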

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 134 additions & 3 deletions
@@ -47,7 +47,7 @@ use datafusion_common::tree_node::{
 };
 use datafusion_common::{
     aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints,
-    DFField, DFSchema, DFSchemaRef, DataFusionError, FunctionalDependencies,
+    DFField, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependencies,
     OwnedTableReference, Result, ScalarValue, UnnestOptions,
 };
 // backwards compatibility
@@ -1030,7 +1030,13 @@ impl LogicalPlan {
     pub fn max_rows(self: &LogicalPlan) -> Option<usize> {
         match self {
             LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(),
-            LogicalPlan::Filter(Filter { input, .. }) => input.max_rows(),
+            LogicalPlan::Filter(filter) => {
+                if filter.is_scalar() {
+                    Some(1)
+                } else {
+                    filter.input.max_rows()
+                }
+            }
             LogicalPlan::Window(Window { input, .. }) => input.max_rows(),
             LogicalPlan::Aggregate(Aggregate {
                 input, group_expr, ..
@@ -1917,6 +1923,73 @@ impl Filter {
 
         Ok(Self { predicate, input })
     }
+
+    /// Is this filter guaranteed to return 0 or 1 row in a given instantiation?
+    ///
+    /// This function will return `true` if its predicate contains a conjunction of
+    /// `col(a) = <expr>`, where its schema has a unique constraint that is covered
+    /// by this conjunction.
+    ///
+    /// For example, for the table:
+    /// ```sql
+    /// CREATE TABLE t (a INTEGER PRIMARY KEY, b INTEGER);
+    /// ```
+    /// `Filter(a = 2).is_scalar() == true`
+    /// , whereas
+    /// `Filter(b = 2).is_scalar() == false`
+    /// and
+    /// `Filter(a = 2 OR b = 2).is_scalar() == false`
+    fn is_scalar(&self) -> bool {
+        let schema = self.input.schema();
+
+        let functional_dependencies = self.input.schema().functional_dependencies();
+        let unique_keys = functional_dependencies.iter().filter(|dep| {
+            let nullable = dep.nullable
+                && dep
+                    .source_indices
+                    .iter()
+                    .any(|&source| schema.field(source).is_nullable());
+            !nullable
+                && dep.mode == Dependency::Single
+                && dep.target_indices.len() == schema.fields().len()
+        });
+
+        let exprs = split_conjunction(&self.predicate);
+        let eq_pred_cols: HashSet<_> = exprs
+            .iter()
+            .filter_map(|expr| {
+                let Expr::BinaryExpr(BinaryExpr {
+                    left,
+                    op: Operator::Eq,
+                    right,
+                }) = expr
+                else {
+                    return None;
+                };
+                // This is a no-op filter expression
+                if left == right {
+                    return None;
+                }
+
+                match (left.as_ref(), right.as_ref()) {
+                    (Expr::Column(_), Expr::Column(_)) => None,
+                    (Expr::Column(c), _) | (_, Expr::Column(c)) => {
+                        Some(schema.index_of_column(c).unwrap())
+                    }
+                    _ => None,
+                }
+            })
+            .collect();
+
+        // If we have a functional dependence that is a subset of our predicate,
+        // this filter is scalar
+        for key in unique_keys {
+            if key.source_indices.iter().all(|c| eq_pred_cols.contains(c)) {
+                return true;
+            }
+        }
+        false
+    }
 }
 
 /// Window its input based on a set of window spec and window function (e.g. SUM or RANK)
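
For illustration only (not part of this commit), a hedged sketch of how the max_rows change and is_scalar combine. The names scan and unique_schema are borrowed from the test_filter_is_scalar test added further down, where a Unique constraint is placed on the scan's single non-nullable column:

    // `scan` wraps a TableScan whose projected schema carries a Unique
    // constraint on its only (non-nullable) column, as built in the test below.
    let key = unique_schema.field(0).qualified_column();
    let filter = Filter::try_new(
        Expr::Column(key).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
        scan,
    )
    .unwrap();
    // The predicate pins down a unique key, so the plan is known to return at
    // most one row; previously max_rows() simply deferred to the input.
    assert_eq!(LogicalPlan::Filter(filter).max_rows(), Some(1));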
@@ -2552,12 +2625,14 @@ pub struct Unnest {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::builder::LogicalTableSource;
     use crate::logical_plan::table_scan;
     use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::tree_node::TreeNodeVisitor;
-    use datafusion_common::{not_impl_err, DFSchema, TableReference};
+    use datafusion_common::{not_impl_err, Constraint, DFSchema, TableReference};
     use std::collections::HashMap;
+    use std::sync::Arc;
 
     fn employee_schema() -> Schema {
         Schema::new(vec![
@@ -3052,4 +3127,60 @@ digraph {
             .unwrap()
             .is_nullable());
     }
+    #[test]
+    fn test_filter_is_scalar() {
+        // scan whose schema has no unique constraint: the filter cannot be scalar
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let source = Arc::new(LogicalTableSource::new(schema));
+        let schema = Arc::new(
+            DFSchema::try_from_qualified_schema(
+                TableReference::bare("tab"),
+                &source.schema(),
+            )
+            .unwrap(),
+        );
+        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
+            table_name: TableReference::bare("tab"),
+            source: source.clone(),
+            projection: None,
+            projected_schema: schema.clone(),
+            filters: vec![],
+            fetch: None,
+        }));
+        let col = schema.field(0).qualified_column();
+
+        let filter = Filter::try_new(
+            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
+            scan,
+        )
+        .unwrap();
+        assert!(!filter.is_scalar());
+        let unique_schema =
+            Arc::new(schema.as_ref().clone().with_functional_dependencies(
+                FunctionalDependencies::new_from_constraints(
+                    Some(&Constraints::new_unverified(vec![Constraint::Unique(
+                        vec![0],
+                    )])),
+                    1,
+                ),
+            ));
+        let scan = Arc::new(LogicalPlan::TableScan(TableScan {
+            table_name: TableReference::bare("tab"),
+            source,
+            projection: None,
+            projected_schema: unique_schema.clone(),
+            filters: vec![],
+            fetch: None,
+        }));
+        let col = schema.field(0).qualified_column();
+
+        let filter = Filter::try_new(
+            Expr::Column(col).eq(Expr::Literal(ScalarValue::Int32(Some(1)))),
+            scan,
+        )
+        .unwrap();
+        assert!(filter.is_scalar());
+    }
 }

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 18 additions & 0 deletions
@@ -49,6 +49,13 @@ CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
 (44, 'x', 3),
 (55, 'w', 3);
 
+statement ok
+CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES
+(11, 'e', 3),
+(22, 'f', 1),
+(44, 'g', 3),
+(55, 'h', 3);
+
 statement ok
 CREATE EXTERNAL TABLE IF NOT EXISTS customer (
 c_custkey BIGINT,
@@ -419,6 +426,17 @@ SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2
 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
 SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int) as t2_int from t1
 
+#non_aggregated_correlated_scalar_subquery_unique
+query II rowsort
+SELECT t1_id, (SELECT t3_int FROM t3 WHERE t3.t3_id = t1.t1_id) as t3_int from t1
+----
+11 3
+22 1
+33 NULL
+44 3
+
+
+#non_aggregated_correlated_scalar_subquery
 statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated scalar subquery must be aggregated to return at most one row
 SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1_int group by t2_int) as t2_int from t1
