[SPARK-12032] [SQL] Re-order inner joins to do join with conditions first #10073

Status: Closed. Wants to merge 10 commits.
@@ -18,14 +18,12 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet

import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.LeftOuter
import org.apache.spark.sql.catalyst.plans.RightOuter
import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
@@ -44,6 +42,7 @@ object DefaultOptimizer extends Optimizer {
// Operator push down
SetOperationPushDown,
SamplePushDown,
ReorderJoin,
PushPredicateThroughJoin,
PushPredicateThroughProject,
PushPredicateThroughGenerate,
@@ -711,6 +710,53 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
}
}

/**
* Reorder the joins and push all the conditions into join, so that the bottom ones have at least
* one condition.
*
* The order of joins will not be changed if all of them already have at least one condition.
*/
Contributor: Can you add to this comment what makes this rule stable? It's not obvious from reading the code.

object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {

/**
Contributor: Remove this comment if it is the same as the object comment, or augment it with more detail. Can you document what the input arguments are? What is `input`? The least common ancestor of the joins? Similarly for `conditions`.

Contributor Author: Updated.
* Join a list of plans together and push down the conditions into them.
*
* The joined plans are picked from left to right, preferring those that have at least one join condition.
Contributor: Does this mean we generate a new identical tree each time this is run? Does this mess up the optimizer termination logic?

Contributor Author: The transform will reuse the original tree if a rule returns an identical tree, so we don't need to do this optimization manually. @marmbrus is that right?

Contributor: This should be okay. We check reference equality first but then also check equals. As long as it's not going to oscillate between plans, it should terminate.
*
* @param input a list of LogicalPlans to join.
* @param conditions a list of conditions for the join.
*/
def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = {
assert(input.size >= 2)
if (input.size == 2) {
Contributor: assert(input.size > 2)? Then we don't need this if branch.

Contributor Author: We need a branch to terminate this recursive call anyway.

Contributor: Ah, sorry, I missed it.
Join(input(0), input(1), Inner, conditions.reduceLeftOption(And))
} else {
val left :: rest = input.toList
// find the first plan in `rest` that has at least one join condition with `left`
val conditionalJoin = rest.find { plan =>
val refs = left.outputSet ++ plan.outputSet
conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan))
Contributor: Would conditions.exists(_.references.intersect(refs).size == 2) be simpler?

Contributor Author: That's not right; the two references could both come from left or both from right.

Contributor: I see, thanks!
.exists(_.references.subsetOf(refs))
Contributor: The implementation of canEvaluate:

protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
    expr.references.subsetOf(plan.outputSet)

What if the condition is 'a === 'b while 'a is in left and 'b is in plan? We will return false here when the plan is actually qualified. So I think conditions.exists(_.references.subsetOf(refs)) is enough here.

Contributor Author: No. In this case the expression cannot be evaluated by only left or only plan; it can only be evaluated after the join.

Contributor: But why do we need to ensure the condition can be evaluated by only left or plan? Is a join condition allowed to reference both left and right?

Contributor Author: It's filterNot: an expression is a join condition only when its references come from both left and right, so we should exclude those whose references all come from only left or only right.

Contributor: Ah, I see, thanks for the explanation!

}
// pick the next plan if no remaining plan has a join condition
val right = conditionalJoin.getOrElse(rest.head)

val joinedRefs = left.outputSet ++ right.outputSet
val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs))
val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And))

// remove `right` by reference equality; the inputs should not contain two references to the same logical plan
createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others)
Contributor: I think that this eq is safe, even in the presence of shared subtrees from self-joins (since the analyzer will rewrite the tree to avoid conflicting expression ids), but it might be slightly clearer to use partition above instead of find, if that's not too much work.
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case j @ ExtractFiltersAndInnerJoins(input, conditions)
if input.size > 2 && conditions.nonEmpty =>
createOrderedJoin(input, conditions)
}
}
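The greedy ordering implemented by createOrderedJoin can be sketched outside Catalyst with a toy plan model. This is a minimal sketch under assumed names (`Plan`, `Cond`, and `orderJoins` are hypothetical, not Spark API): plans are just sets of attribute names, and at each step we pick the first remaining plan that shares a join condition with what has been joined so far.

```scala
object JoinOrderSketch {
  // A toy plan is a name plus the attributes it outputs.
  case class Plan(name: String, attrs: Set[String])
  // A toy equi-condition referencing exactly two attributes.
  case class Cond(a: String, b: String) { def refs: Set[String] = Set(a, b) }

  // Mirrors createOrderedJoin: join plans left to right, preferring the first
  // plan that, together with `left`, covers some cross-plan condition.
  // Unlike the real rule (which attaches remaining conditions with And),
  // this sketch simply returns any leftover conditions.
  def orderJoins(input: List[Plan], conditions: List[Cond]): (Plan, List[Cond]) = {
    require(input.size >= 2)
    if (input.size == 2) {
      val joined = Plan(s"(${input(0).name} JOIN ${input(1).name})",
        input(0).attrs ++ input(1).attrs)
      (joined, conditions)
    } else {
      val left :: rest = input
      // first plan whose attributes, combined with left's, satisfy a condition
      // that neither side could evaluate alone (i.e. a real join condition)
      val right = rest.find { p =>
        val refs = left.attrs ++ p.attrs
        conditions.exists(c => c.refs.subsetOf(refs) &&
          !c.refs.subsetOf(left.attrs) && !c.refs.subsetOf(p.attrs))
      }.getOrElse(rest.head)
      val joinedRefs = left.attrs ++ right.attrs
      val (used, others) = conditions.partition(_.refs.subsetOf(joinedRefs))
      val joined = Plan(s"(${left.name} JOIN ${right.name})", joinedRefs)
      orderJoins(joined :: rest.filterNot(_ eq right), others)
    }
  }
}
```

With plans x(xa, xb), y(yd), z(za, zb) and conditions xb = zb and yd = za, the sketch joins x with z first (the only pair covering a condition) and only then brings in y, matching the reordering the test suite below expects.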

/**
* Pushes down [[Filter]] operators where the `condition` can be
* evaluated using only the attributes of the left or right side of a join. Other
@@ -21,7 +21,6 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef

/**
* A pattern that matches any number of project or filter operations on top of another relational
@@ -132,6 +131,45 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
}
}

/**
* A pattern that collects the filter and inner joins.
Contributor: Is it much more work to extract all the filters? For example, if there is a filter after the inner join of input and plan 1, we'd ideally use this for predicate propagation as well. For example: select * from t1 join t2 on t1.key = t2.key and t1.key = 5. If we collected all the filters, this could be used to infer t2.key = 5 and push that down to t2.

Contributor Author: Done.

*
* Filter
* |
* inner Join
* / \ ----> (Seq(plan0, plan1, plan2), conditions)
* Filter plan2
* |
* inner join
* / \
* plan0 plan1
*
* Note: This pattern currently only works for left-deep trees.
*/
object ExtractFiltersAndInnerJoins extends PredicateHelper {

// flatten adjacent inner joins into a list of plans and a list of conditions
def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match {
case Join(left, right, Inner, cond) =>
val (plans, conditions) = flattenJoin(left)
(plans ++ Seq(right), conditions ++ cond.toSeq)
Contributor: Should this be splitConjunctivePredicates(conditions) ++ cond.toSeq?

Contributor Author: conditions is already a list of Expressions (already split).

case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) =>
Contributor: Join(left, right, Inner, joinCondition) => Join(left, right, Inner, None)

Contributor: Maybe just j @ Join(_, _, Inner, _), since left, right, and joinCondition are not used.

val (plans, conditions) = flattenJoin(j)
(plans, conditions ++ splitConjunctivePredicates(filterCondition))

case _ => (Seq(plan), Seq())
}

def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match {
case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) =>
Some(flattenJoin(f))
case j @ Join(_, _, Inner, _) =>
Some(flattenJoin(j))
case _ => None
}
}
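The left-spine flattening performed by flattenJoin can be sketched with a toy tree model. This is a minimal sketch under assumed names (`Node`, `Leaf`, `InnerJoin`, `FilterNode`, and `flatten` are hypothetical, not Catalyst classes): recurse only down the left child of each inner join, collecting leaf plans and accumulating both join conditions and filter predicates.

```scala
object FlattenSketch {
  sealed trait Node
  case class Leaf(name: String) extends Node
  // An inner join with an optional condition (conditions are kept as strings here).
  case class InnerJoin(left: Node, right: Node, cond: Option[String]) extends Node
  // A filter whose predicates are already split into conjuncts.
  case class FilterNode(conds: List[String], child: Node) extends Node

  // Mirrors flattenJoin: only the left spine is flattened (left-deep trees),
  // so a join appearing as a right child stays intact as a single plan.
  def flatten(plan: Node): (List[Node], List[String]) = plan match {
    case InnerJoin(left, right, cond) =>
      val (plans, conds) = flatten(left)
      (plans :+ right, conds ++ cond.toList)
    case FilterNode(fconds, j: InnerJoin) =>
      // a filter directly on top of an inner join contributes its predicates
      val (plans, conds) = flatten(j)
      (plans, conds ++ fconds)
    case other => (List(other), Nil)
  }
}
```

For Filter(x.b = y.d, Join(Join(x, y), z, x.b = z.b)) the sketch yields the plans [x, y, z] together with both conditions, which is exactly the shape ReorderJoin consumes.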

/**
* A pattern that collects all adjacent unions and returns their children as a Seq.
*/
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class JoinOrderSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Filter Pushdown", Once,
CombineFilters,
PushPredicateThroughProject,
BooleanSimplification,
ReorderJoin,
PushPredicateThroughJoin,
PushPredicateThroughGenerate,
PushPredicateThroughAggregate,
ColumnPruning,
ProjectCollapsing) :: Nil

}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val testRelation1 = LocalRelation('d.int)

test("extract filters and joins") {
val x = testRelation.subquery('x)
val y = testRelation1.subquery('y)
val z = testRelation.subquery('z)

def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]) {
assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected)
}

testExtract(x, None)
testExtract(x.where("x.b".attr === 1), None)
testExtract(x.join(y), Some(Seq(x, y), Seq()))
testExtract(x.join(y, condition = Some("x.b".attr === "y.d".attr)),
Some(Seq(x, y), Seq("x.b".attr === "y.d".attr)))
testExtract(x.join(y).where("x.b".attr === "y.d".attr),
Some(Seq(x, y), Seq("x.b".attr === "y.d".attr)))
testExtract(x.join(y).join(z), Some(Seq(x, y, z), Seq()))
testExtract(x.join(y).where("x.b".attr === "y.d".attr).join(z),
Some(Seq(x, y, z), Seq("x.b".attr === "y.d".attr)))
testExtract(x.join(y).join(x.join(z)), Some(Seq(x, y, x.join(z)), Seq()))
testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr),
Some(Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr)))
}

test("reorder inner joins") {
val x = testRelation.subquery('x)
val y = testRelation1.subquery('y)
val z = testRelation.subquery('z)

val originalQuery = {
x.join(y).join(z)
.where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr))
}

val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
x.join(z, condition = Some("x.b".attr === "z.b".attr))
.join(y, condition = Some("y.d".attr === "z.a".attr))
.analyze

comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer))
}
}