Skip to content

Commit fbe82fb

Browse files
cloud-fanwangyum
authored andcommitted
[SPARK-36194][SQL][FOLLOWUP] Propagate distinct keys more precisely
### What changes were proposed in this pull request? This PR is a followup of #35779 , to propagate distinct keys more precisely in 2 cases: 1. For `LIMIT 1`, each output attribute is a distinct key, not the entire tuple. 2. For aggregate, we can still propagate distinct keys from child. ### Why are the changes needed? make the optimization cover more cases ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new tests Closes #36100 from cloud-fan/followup. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Yuming Wang <yumwang@ebay.com>
1 parent 47991b0 commit fbe82fb

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitor.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,27 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
5050
}.filter(_.nonEmpty)
5151
}
5252

53+
/**
54+
* Add a new ExpressionSet S into distinctKeys D.
55+
* To minimize the size of D:
56+
* 1. If there is a subset of S in D, return D.
57+
* 2. Otherwise, remove all the ExpressionSet containing S from D, and add the new one.
58+
*/
59+
private def addDistinctKey(
60+
keys: Set[ExpressionSet],
61+
newExpressionSet: ExpressionSet): Set[ExpressionSet] = {
62+
if (keys.exists(_.subsetOf(newExpressionSet))) {
63+
keys
64+
} else {
65+
keys.filterNot(s => newExpressionSet.subsetOf(s)) + newExpressionSet
66+
}
67+
}
68+
5369
override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet]
5470

5571
override def visitAggregate(p: Aggregate): Set[ExpressionSet] = {
5672
val groupingExps = ExpressionSet(p.groupingExpressions) // handle group by a, a
57-
projectDistinctKeys(Set(groupingExps), p.aggregateExpressions)
73+
projectDistinctKeys(addDistinctKey(p.child.distinctKeys, groupingExps), p.aggregateExpressions)
5874
}
5975

6076
override def visitDistinct(p: Distinct): Set[ExpressionSet] = Set(ExpressionSet(p.output))
@@ -70,7 +86,7 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
7086

7187
override def visitGlobalLimit(p: GlobalLimit): Set[ExpressionSet] = {
7288
p.maxRows match {
73-
case Some(value) if value <= 1 => Set(ExpressionSet(p.output))
89+
case Some(value) if value <= 1 => p.output.map(attr => ExpressionSet(Seq(attr))).toSet
7490
case _ => p.child.distinctKeys
7591
}
7692
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/DistinctKeyVisitorSuite.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ class DistinctKeyVisitorSuite extends PlanTest {
6666
Set(ExpressionSet(Seq(a)), ExpressionSet(Seq(d.toAttribute))))
6767
checkDistinctAttributes(t1.groupBy(f.child, $"b")(f, $"b", sum($"c")),
6868
Set(ExpressionSet(Seq(f.toAttribute, b))))
69+
70+
// Aggregate should also propagate distinct keys from child
71+
checkDistinctAttributes(t1.limit(1).groupBy($"a", $"b")($"a", $"b"),
72+
Set(ExpressionSet(Seq(a)), ExpressionSet(Seq(b))))
6973
}
7074

7175
test("Distinct's distinct attributes") {
@@ -86,7 +90,8 @@ class DistinctKeyVisitorSuite extends PlanTest {
8690
test("Limit's distinct attributes") {
8791
checkDistinctAttributes(Distinct(t1).limit(10), Set(ExpressionSet(Seq(a, b, c))))
8892
checkDistinctAttributes(LocalLimit(10, Distinct(t1)), Set(ExpressionSet(Seq(a, b, c))))
89-
checkDistinctAttributes(t1.limit(1), Set(ExpressionSet(Seq(a, b, c))))
93+
checkDistinctAttributes(t1.limit(1),
94+
Set(ExpressionSet(Seq(a)), ExpressionSet(Seq(b)), ExpressionSet(Seq(c))))
9095
}
9196

9297
test("Intersect's distinct attributes") {

0 commit comments

Comments
 (0)