Skip to content

Commit 21fb7c5

Browse files
committed
Rewriting join condition to conjunctive normal form expression
1 parent f0e2fc3 commit 21fb7c5

File tree

3 files changed

+196
-2
lines changed

3 files changed

+196
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
118118
Batch("Infer Filters", Once,
119119
InferFiltersFromConstraints) ::
120120
Batch("Operator Optimization after Inferring Filters", fixedPoint,
121-
rulesWithoutInferFiltersFromConstraints: _*) :: Nil
121+
rulesWithoutInferFiltersFromConstraints: _*) ::
122+
Batch("Push predicate through join by conjunctive normal form", Once,
123+
PushPredicateThroughJoinByCNF) :: Nil
122124
}
123125

124126
val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) ::
@@ -1372,6 +1374,80 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
13721374
}
13731375
}
13741376

1377+
/**
 * Rewrites a join condition to a conjunctive normal form (CNF) expression so that more
 * predicates can be pushed below the join.
 *
 * The original join condition is always kept on the resulting [[Join]]. The conjuncts obtained
 * from the CNF form are only used to add extra [[Filter]] operators on top of the join children
 * whose output covers all of the conjunct's references.
 */
object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper {

  /**
   * Rewrite pattern:
   * 1. (a && b) || c --> (a || c) && (b || c)
   * 2. a || (b && c) --> (a || b) && (a || c)
   * 3. !(a || b) --> !a && !b
   *
   * The recursion stops once `depth` reaches `SQLConf.get.maxRewritingCNFDepth`; from that point
   * the (sub-)expression is returned unchanged, so the result may be only partially converted.
   *
   * @param condition the expression to rewrite
   * @param depth     the current recursion depth, 0 for the outermost call
   * @return the rewritten expression, or `condition` itself when no pattern applies or the
   *         depth limit has been reached
   */
  private def rewriteToCNF(condition: Expression, depth: Int = 0): Expression = {
    if (depth < SQLConf.get.maxRewritingCNFDepth) {
      val nextDepth = depth + 1
      condition match {
        case Or(And(a, b), c) =>
          // `c` appears in both resulting disjunctions: rewrite it once and reuse the result.
          val newC = rewriteToCNF(c, nextDepth)
          And(rewriteToCNF(Or(rewriteToCNF(a, nextDepth), newC), nextDepth),
            rewriteToCNF(Or(rewriteToCNF(b, nextDepth), newC), nextDepth))
        case Or(a, And(b, c)) =>
          // `a` appears in both resulting disjunctions: rewrite it once and reuse the result.
          val newA = rewriteToCNF(a, nextDepth)
          And(rewriteToCNF(Or(newA, rewriteToCNF(b, nextDepth)), nextDepth),
            rewriteToCNF(Or(newA, rewriteToCNF(c, nextDepth)), nextDepth))
        case Not(Or(a, b)) =>
          And(rewriteToCNF(Not(rewriteToCNF(a, nextDepth)), nextDepth),
            rewriteToCNF(Not(rewriteToCNF(b, nextDepth)), nextDepth))
        case And(a, b) =>
          And(rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth))
        case other => other
      }
    } else {
      condition
    }
  }

  /**
   * Puts a [[Filter]] holding the conjunction of `joinCondition` on top of `plan`.
   *
   * `plan` is returned untouched when there is nothing to push, or when it already is a
   * [[Filter]] whose condition is semantically equal to the conjunction.
   */
  private def maybeWithFilter(joinCondition: Seq[Expression], plan: LogicalPlan): LogicalPlan = {
    (joinCondition.reduceLeftOption(And), plan) match {
      case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) =>
        plan
      case (Some(condition), _) =>
        Filter(condition, plan)
      case _ =>
        plan
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally

  val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
    case j @ Join(left, right, joinType, Some(joinCondition), hint) =>

      // Only deterministic conjuncts of the CNF form are candidates for push down.
      val pushDownCandidates = splitConjunctivePredicates(rewriteToCNF(joinCondition))
        .filter(_.deterministic)
      // Conjuncts that only reference the left output go left; of the remaining ones, those
      // that only reference the right output go right. Everything else is not pushed.
      val (leftEvaluateCondition, rest) =
        pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
      val rightEvaluateCondition = rest.filter(_.references.subsetOf(right.outputSet))

      val newLeft = maybeWithFilter(leftEvaluateCondition, left)
      val newRight = maybeWithFilter(rightEvaluateCondition, right)

      // The original join condition is preserved in every case; the join type only decides
      // which children receive the additional filters.
      joinType match {
        case _: InnerLike | LeftSemi =>
          Join(newLeft, newRight, joinType, Some(joinCondition), hint)
        case RightOuter =>
          Join(newLeft, right, RightOuter, Some(joinCondition), hint)
        case LeftOuter | LeftAnti | ExistenceJoin(_) =>
          Join(left, newRight, joinType, Some(joinCondition), hint)
        case FullOuter => j
        case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
        case UsingJoin(_, _) => sys.error("Untransformed Using join node")
      }
  }
}
1450+
13751451
/**
13761452
* Combines two adjacent [[Limit]] operators into one, merging the
13771453
* expressions into one single expression.

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,18 @@ object SQLConf {
544544
.booleanConf
545545
.createWithDefault(true)
546546

547+
// Internal knob bounding the recursion depth of the join-condition CNF rewrite.
// Zero is an accepted value (see the `>= 0` check below) and is documented as disabling
// the feature, so the validation message says "non-negative" rather than "positive".
val MAX_REWRITING_CNF_DEPTH =
  buildConf("spark.sql.maxRewritingCNFDepth")
    .internal()
    .doc("The maximum depth of rewriting a join condition to conjunctive normal form " +
      "expression. The deeper, the more predicate may be found, but the optimization time " +
      "will increase. The default is 6. By setting this value to 0 this feature can be disabled.")
    .version("3.1.0")
    .intConf
    .checkValue(_ >= 0,
      "The depth of the maximum rewriting conjunctive normal form must be non-negative.")
    .createWithDefault(6)
558+
547559
val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
548560
.internal()
549561
.doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -2845,6 +2857,8 @@ class SQLConf extends Serializable with Logging {
28452857

28462858
def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
28472859

2860+
/** Returns the configured value of MAX_REWRITING_CNF_DEPTH. */
def maxRewritingCNFDepth: Int = {
  getConf(MAX_REWRITING_CNF_DEPTH)
}
2861+
28482862
def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
28492863

28502864
def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ class FilterPushdownSuite extends PlanTest {
3939
PushPredicateThroughNonJoin,
4040
BooleanSimplification,
4141
PushPredicateThroughJoin,
42-
CollapseProject) :: Nil
42+
CollapseProject) ::
43+
Batch("PushPredicateThroughJoinByCNF", Once,
44+
PushPredicateThroughJoinByCNF) :: Nil
4345
}
4446

4547
val attrA = 'a.int
@@ -1230,4 +1232,106 @@ class FilterPushdownSuite extends PlanTest {
12301232

12311233
comparePlans(Optimize.execute(query.analyze), expected)
12321234
}
1235+
1236+
test("inner join: rewrite filter predicates to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Equality on `b` combined with a disjunction of two conjunctions that each mix both sides.
  val predicate = ("x.b".attr === "y.b".attr) &&
    ((("x.a".attr > 3) && ("y.a".attr > 13)) || (("x.a".attr > 1) && ("y.a".attr > 11)))

  // The predicate starts out as a filter on top of a condition-less join.
  val originalQuery = x.join(y).where(predicate)
  val optimized = Optimize.execute(originalQuery.analyze)

  // Expected: single-side disjunctions on each child, full predicate as the join condition.
  val left = testRelation.where(('a > 3) || ('a > 1)).subquery('x)
  val right = testRelation.where(('a > 13) || ('a > 11)).subquery('y)
  val correctAnswer = left.join(right, condition = Some(predicate)).analyze

  comparePlans(optimized, correctAnswer)
}
1256+
1257+
test("inner join: rewrite join predicates to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Equality on `b` combined with a disjunction of two conjunctions that each mix both sides.
  val joinCondition = ("x.b".attr === "y.b".attr) &&
    ((("x.a".attr > 3) && ("y.a".attr > 13)) || (("x.a".attr > 1) && ("y.a".attr > 11)))

  val originalQuery = x.join(y, condition = Some(joinCondition))
  val optimized = Optimize.execute(originalQuery.analyze)

  // Expected: single-side disjunctions on each child, join condition left as written.
  val left = testRelation.where(('a > 3) || ('a > 1)).subquery('x)
  val right = testRelation.where(('a > 13) || ('a > 11)).subquery('y)
  val correctAnswer = left.join(right, condition = Some(joinCondition)).analyze

  comparePlans(optimized, correctAnswer)
}
1276+
1277+
test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Equality on `b` combined with a negated disjunction.
  val joinCondition = ("x.b".attr === "y.b".attr) && Not(
    (("x.a".attr > 3) && (("x.a".attr < 2) || ("y.a".attr > 13))) ||
      (("x.a".attr > 1) && ("y.a".attr > 11)))

  val originalQuery = x.join(y, condition = Some(joinCondition))
  val optimized = Optimize.execute(originalQuery.analyze)

  // Expected: the negation no longer appears in the join condition, and only the left
  // child receives an extra filter.
  val expectedCondition = ("x.b".attr === "y.b".attr) &&
    (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13))) &&
    (("x.a".attr <= 1) || ("y.a".attr <= 11))
  val left = testRelation.where(('a <= 3) || ('a >= 2)).subquery('x)
  val right = testRelation.subquery('y)
  val correctAnswer = left.join(right, condition = Some(expectedCondition)).analyze

  comparePlans(optimized, correctAnswer)
}
1297+
1298+
test("left join: rewrite join predicates to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Equality on `b` combined with a disjunction of two conjunctions that each mix both sides.
  val joinCondition = ("x.b".attr === "y.b".attr) &&
    ((("x.a".attr > 3) && ("y.a".attr > 13)) || (("x.a".attr > 1) && ("y.a".attr > 11)))

  val originalQuery = x.join(y, joinType = LeftOuter, condition = Some(joinCondition))
  val optimized = Optimize.execute(originalQuery.analyze)

  // Expected: only the right child receives a filter; the left child is untouched.
  val left = testRelation.subquery('x)
  val right = testRelation.where(('a > 13) || ('a > 11)).subquery('y)
  val correctAnswer =
    left.join(right, joinType = LeftOuter, condition = Some(joinCondition)).analyze

  comparePlans(optimized, correctAnswer)
}
1317+
1318+
test("right join: rewrite join predicates to conjunctive normal form") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Equality on `b` combined with a disjunction of two conjunctions that each mix both sides.
  val joinCondition = ("x.b".attr === "y.b".attr) &&
    ((("x.a".attr > 3) && ("y.a".attr > 13)) || (("x.a".attr > 1) && ("y.a".attr > 11)))

  val originalQuery = x.join(y, joinType = RightOuter, condition = Some(joinCondition))
  val optimized = Optimize.execute(originalQuery.analyze)

  // Expected: only the left child receives a filter; the right child is untouched.
  val left = testRelation.where(('a > 3) || ('a > 1)).subquery('x)
  val right = testRelation.subquery('y)
  val correctAnswer =
    left.join(right, joinType = RightOuter, condition = Some(joinCondition)).analyze

  comparePlans(optimized, correctAnswer)
}
12331337
}

0 commit comments

Comments
 (0)