Skip to content

Commit

Permalink
[feature](Nereids) add predicates push down on all join type (apache#…
Browse files Browse the repository at this point in the history
…12571)

* [feature](Nereids) add predicates push down on all join type
  • Loading branch information
morrySnow authored Sep 15, 2022
1 parent 5b6d48e commit 858e823
Show file tree
Hide file tree
Showing 12 changed files with 594 additions and 292 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
import org.apache.doris.nereids.rules.mv.SelectRollup;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
Expand All @@ -27,14 +28,8 @@
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicateThroughJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -64,15 +59,9 @@ public RewriteJob(CascadesContext cascadesContext) {
.add(topDownBatch(ImmutableList.of(new ExpressionNormalization())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
.add(topDownBatch(ImmutableList.of(new PushPredicateThroughJoin(),
new PushdownProjectThroughLimit(),
new PushdownFilterThroughProject(),
new MergeConsecutiveProjects(),
new MergeConsecutiveFilters(),
new MergeConsecutiveLimits())))
.add(topDownBatch(RuleSet.PUSH_DOWN_JOIN_CONDITION_RULES))
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
.add(topDownBatch(ImmutableList.of(new AggregateDisassemble())))
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))
.add(topDownBatch(ImmutableList.of(new EliminateLimit())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@
import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject;
import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort;
import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveLimits;
import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects;
import org.apache.doris.nereids.rules.rewrite.logical.PushDownJoinOtherCondition;
import org.apache.doris.nereids.rules.rewrite.logical.PushPredicatesThroughJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownProjectThroughLimit;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
Expand All @@ -55,9 +59,14 @@ public class RuleSet {
.add(new MergeConsecutiveProjects())
.build();

public static final List<Rule> REWRITE_RULES = planRuleFactories()
.add(new AggregateDisassemble())
.build();
public static final List<RuleFactory> PUSH_DOWN_JOIN_CONDITION_RULES = ImmutableList.of(
new PushDownJoinOtherCondition(),
new PushPredicatesThroughJoin(),
new PushdownProjectThroughLimit(),
new PushdownFilterThroughProject(),
new MergeConsecutiveProjects(),
new MergeConsecutiveFilters(),
new MergeConsecutiveLimits());

public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories()
.add(new LogicalAggToPhysicalHashAgg())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public enum RuleType {
EXISTS_APPLY_TO_JOIN(RuleTypeClass.REWRITE),
// predicate push down rules
PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_DOWN_JOIN_OTHER_CONDITION(RuleTypeClass.REWRITE),
PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE),
// column prune rules,
COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite.logical;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Set;

/**
* Push the other join conditions in LogicalJoin to children.
*/
public class PushDownJoinOtherCondition extends OneRewriteRuleFactory {
private static final ImmutableList<JoinType> PUSH_DOWN_LEFT_VALID_TYPE = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_SEMI_JOIN,
JoinType.RIGHT_OUTER_JOIN,
JoinType.RIGHT_ANTI_JOIN,
JoinType.RIGHT_SEMI_JOIN,
JoinType.CROSS_JOIN
);

private static final ImmutableList<JoinType> PUSH_DOWN_RIGHT_VALID_TYPE = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_OUTER_JOIN,
JoinType.LEFT_ANTI_JOIN,
JoinType.LEFT_SEMI_JOIN,
JoinType.RIGHT_SEMI_JOIN,
JoinType.CROSS_JOIN
);

@Override
public Rule build() {
return logicalJoin().then(join -> {
if (!join.getOtherJoinCondition().isPresent()) {
return null;
}
List<Expression> otherConjuncts = ExpressionUtils.extractConjunction(join.getOtherJoinCondition().get());
List<Expression> leftConjuncts = Lists.newArrayList();
List<Expression> rightConjuncts = Lists.newArrayList();

for (Expression otherConjunct : otherConjuncts) {
if (PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType())
&& allCoveredBy(otherConjunct, join.left().getOutputSet())) {
leftConjuncts.add(otherConjunct);
}
if (PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType())
&& allCoveredBy(otherConjunct, join.right().getOutputSet())) {
rightConjuncts.add(otherConjunct);
}
}

if (leftConjuncts.isEmpty() && rightConjuncts.isEmpty()) {
return null;
}

otherConjuncts.removeAll(leftConjuncts);
otherConjuncts.removeAll(rightConjuncts);

Plan left = PlanUtils.filterOrSelf(leftConjuncts, join.left());
Plan right = PlanUtils.filterOrSelf(rightConjuncts, join.right());

return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
ExpressionUtils.optionalAnd(otherConjuncts), left, right);

}).toRule(RuleType.PUSH_DOWN_JOIN_OTHER_CONDITION);
}

private boolean allCoveredBy(Expression predicate, Set<Slot> inputSlotSet) {
return inputSlotSet.containsAll(predicate.getInputSlots());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,103 +21,130 @@
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
* Push the predicate in the LogicalFilter or LogicalJoin to the join children.
* todo: Now, only support eq on condition for inner join, support other case later
* Push the predicate in the LogicalFilter to the join children.
*/
public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
public class PushPredicatesThroughJoin extends OneRewriteRuleFactory {

private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_LEFT = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.LEFT_OUTER_JOIN,
JoinType.LEFT_SEMI_JOIN,
JoinType.LEFT_ANTI_JOIN,
JoinType.CROSS_JOIN
);

private static final ImmutableList<JoinType> COULD_PUSH_THROUGH_RIGHT = ImmutableList.of(
JoinType.INNER_JOIN,
JoinType.RIGHT_OUTER_JOIN,
JoinType.RIGHT_SEMI_JOIN,
JoinType.RIGHT_ANTI_JOIN,
JoinType.CROSS_JOIN
);

private static final ImmutableList<JoinType> COULD_PUSH_EQUAL_TO = ImmutableList.of(
JoinType.INNER_JOIN
);

/*
* For example:
* select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2
* select a.k1, b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5
* where a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2
*
* Logical plan tree:
* project
* |
* filter (a.k1 > 1 and b.k1 > 2)
* filter (a.k1 > 1 and b.k1 > 2 and a.k2 > b.k2)
* |
* join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5)
* / \
* scan scan
* transformed:
* project
* |
* join (a.k1 = b.k1)
* filter(a.k2 > b.k2)
* |
* join (otherConditions: a.k1 = b.k1)
* / \
* filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5)
* filter(a.k1 > 1 and a.k2 > 2) filter(b.k1 > 2 and b.k2 > 5)
* | |
* scan scan
*/
@Override
public Rule build() {
return logicalFilter(innerLogicalJoin()).then(filter -> {
return logicalFilter(logicalJoin()).then(filter -> {

LogicalJoin<GroupPlan, GroupPlan> join = filter.child();

Expression wherePredicates = filter.getPredicates();
Expression onPredicates = join.getOtherJoinCondition().orElse(BooleanLiteral.TRUE);
Expression filterPredicates = filter.getPredicates();

List<Expression> otherConditions = Lists.newArrayList();
List<Expression> eqConditions = Lists.newArrayList();
List<Expression> filterConditions = Lists.newArrayList();
List<Expression> joinConditions = Lists.newArrayList();

Set<Slot> leftInput = join.left().getOutputSet();
Set<Slot> rightInput = join.right().getOutputSet();

ExpressionUtils.extractConjunction(ExpressionUtils.and(onPredicates, wherePredicates))
ExpressionUtils.extractConjunction(filterPredicates)
.forEach(predicate -> {
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
eqConditions.add(predicate);
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))
&& COULD_PUSH_EQUAL_TO.contains(join.getJoinType())) {
joinConditions.add(predicate);
} else {
otherConditions.add(predicate);
filterConditions.add(predicate);
}
});

List<Expression> leftPredicates = Lists.newArrayList();
List<Expression> rightPredicates = Lists.newArrayList();

for (Expression p : otherConditions) {
for (Expression p : filterConditions) {
Set<Slot> slots = p.getInputSlots();
if (slots.isEmpty()) {
leftPredicates.add(p);
rightPredicates.add(p);
continue;
}
if (leftInput.containsAll(slots)) {
if (leftInput.containsAll(slots) && COULD_PUSH_THROUGH_LEFT.contains(join.getJoinType())) {
leftPredicates.add(p);
}
if (rightInput.containsAll(slots)) {
if (rightInput.containsAll(slots) && COULD_PUSH_THROUGH_RIGHT.contains(join.getJoinType())) {
rightPredicates.add(p);
}
}

otherConditions.removeAll(leftPredicates);
otherConditions.removeAll(rightPredicates);
otherConditions.addAll(eqConditions);
filterConditions.removeAll(leftPredicates);
filterConditions.removeAll(rightPredicates);
join.getOtherJoinCondition().map(joinConditions::add);

return pushDownPredicate(join, otherConditions, leftPredicates, rightPredicates);
return PlanUtils.filterOrSelf(filterConditions,
pushDownPredicate(join, joinConditions, leftPredicates, rightPredicates));
}).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
}

private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> joinPlan,
private Plan pushDownPredicate(LogicalJoin<GroupPlan, GroupPlan> join,
List<Expression> joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) {
// todo expr should optimize again using expr rewrite
Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, joinPlan.left());
Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, joinPlan.right());
Plan leftPlan = PlanUtils.filterOrSelf(leftPredicates, join.left());
Plan rightPlan = PlanUtils.filterOrSelf(rightPredicates, join.right());

return new LogicalJoin<>(joinPlan.getJoinType(), joinPlan.getHashJoinConjuncts(),
return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(),
ExpressionUtils.optionalAnd(joinConditions), leftPlan, rightPlan);
}

Expand All @@ -128,13 +155,13 @@ private Expression getJoinCondition(Expression predicate, Set<Slot> leftOutputs,

ComparisonPredicate comparison = (ComparisonPredicate) predicate;

Set<Slot> leftSlots = comparison.left().getInputSlots();
Set<Slot> rightSlots = comparison.right().getInputSlots();

if (!(leftSlots.size() >= 1 && rightSlots.size() >= 1)) {
if (!(comparison instanceof EqualTo)) {
return null;
}

Set<Slot> leftSlots = comparison.left().getInputSlots();
Set<Slot> rightSlots = comparison.right().getInputSlots();

if ((leftOutputs.containsAll(leftSlots) && rightOutputs.containsAll(rightSlots))
|| (leftOutputs.containsAll(rightSlots) && rightOutputs.containsAll(leftSlots))) {
return predicate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
class FindHashConditionForJoinTest {
@Test
public void testFindHashCondition() {
Plan student = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.student, ImmutableList.of(""));
Plan score = new LogicalOlapScan(PlanConstructor.getNextId(), PlanConstructor.score, ImmutableList.of(""));
Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of(""));
Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of(""));

Slot studentId = student.getOutput().get(0);
Slot gender = student.getOutput().get(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
Expand All @@ -48,8 +49,8 @@
import java.util.stream.Collectors;

class LimitPushDownTest extends TestWithFeService implements PatternMatchSupported {
private Plan scanScore = new LogicalOlapScan(PlanConstructor.score);
private Plan scanStudent = new LogicalOlapScan(PlanConstructor.student);
private Plan scanScore = new LogicalOlapScan(new RelationId(0), PlanConstructor.score);
private Plan scanStudent = new LogicalOlapScan(new RelationId(1), PlanConstructor.student);

@Override
protected void runBeforeAll() throws Exception {
Expand Down Expand Up @@ -213,8 +214,8 @@ private Plan generatePlan(JoinType joinType, boolean hasProject) {
joinType,
joinConditions,
Optional.empty(),
new LogicalOlapScan(PlanConstructor.score),
new LogicalOlapScan(PlanConstructor.student)
new LogicalOlapScan(new RelationId(0), PlanConstructor.score),
new LogicalOlapScan(new RelationId(1), PlanConstructor.student)
);

if (hasProject) {
Expand Down
Loading

0 comments on commit 858e823

Please sign in to comment.