Skip to content

Commit

Permalink
[enhancement](Nereids): optimize GroupExpressionMatching (#26196)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener authored Nov 1, 2023
1 parent 502f577 commit 6010be8
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public ApplyRuleJob(GroupExpression groupExpression, Rule rule, JobContext conte
}

@Override
public void execute() throws AnalysisException {
public final void execute() throws AnalysisException {
if (groupExpression.hasApplied(rule)
|| groupExpression.isUnused()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public GroupExpressionIterator iterator() {
public static class GroupExpressionIterator implements Iterator<Plan> {
private final List<Plan> results = Lists.newArrayList();
private int resultIndex = 0;
private int resultsSize;

/**
* Constructor.
Expand Down Expand Up @@ -103,7 +104,7 @@ public GroupExpressionIterator(Pattern<Plan> pattern, GroupExpression groupExpre
// matching children group, one List<Plan> per child
// first dimension: one entry per child group
// second dimension: all matched plans in that child group
List<List<Plan>> childrenPlans = Lists.newArrayListWithCapacity(childrenGroupArity);
List<Plan>[] childrenPlans = new List[childrenGroupArity];
for (int i = 0; i < childrenGroupArity; ++i) {
Group childGroup = groupExpression.child(i);
List<Plan> childrenPlan = matchingChildGroup(pattern, childGroup, i);
Expand All @@ -116,7 +117,7 @@ public GroupExpressionIterator(Pattern<Plan> pattern, GroupExpression groupExpre
return;
}
}
childrenPlans.add(childrenPlan);
childrenPlans[i] = childrenPlan;
}
assembleAllCombinationPlanTree(root, pattern, groupExpression, childrenPlans);
} else if (patternArity == 1 && (pattern.hasMultiChild() || pattern.hasMultiGroupChild())) {
Expand All @@ -127,6 +128,7 @@ public GroupExpressionIterator(Pattern<Plan> pattern, GroupExpression groupExpre
results.add(root);
}
}
this.resultsSize = results.size();
}

private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
Expand Down Expand Up @@ -154,38 +156,35 @@ private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
}

private void assembleAllCombinationPlanTree(Plan root, Pattern<Plan> rootPattern,
GroupExpression groupExpression,
List<List<Plan>> childrenPlans) {
int[] childrenPlanIndex = new int[childrenPlans.size()];
GroupExpression groupExpression, List<Plan>[] childrenPlans) {
int childrenPlansSize = childrenPlans.length;
int[] childrenPlanIndex = new int[childrenPlansSize];
int offset = 0;
LogicalProperties logicalProperties = groupExpression.getOwnerGroup().getLogicalProperties();

// assemble all combinations of plan trees from the current root plan and the children plans
while (offset < childrenPlans.size()) {
ImmutableList.Builder<Plan> childrenBuilder =
ImmutableList.builderWithExpectedSize(childrenPlans.size());
for (int i = 0; i < childrenPlans.size(); i++) {
childrenBuilder.add(childrenPlans.get(i).get(childrenPlanIndex[i]));
Optional<GroupExpression> groupExprOption = Optional.of(groupExpression);
Optional<LogicalProperties> logicalPropOption = Optional.of(logicalProperties);
while (offset < childrenPlansSize) {
ImmutableList.Builder<Plan> childrenBuilder = ImmutableList.builderWithExpectedSize(childrenPlansSize);
for (int i = 0; i < childrenPlansSize; i++) {
childrenBuilder.add(childrenPlans[i].get(childrenPlanIndex[i]));
}
List<Plan> children = childrenBuilder.build();

// assemble children: replace each GroupPlan with a real plan.
// withChildren will erase groupExpression, so we must
// withGroupExpression too.
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(Optional.of(groupExpression),
Optional.of(logicalProperties), children);
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(groupExprOption,
logicalPropOption, children);
if (rootPattern.matchPredicates(rootWithChildren)) {
results.add(rootWithChildren);
}
offset = 0;
while (true) {
for (offset = 0; offset < childrenPlansSize; offset++) {
childrenPlanIndex[offset]++;
if (childrenPlanIndex[offset] == childrenPlans.get(offset).size()) {
if (childrenPlanIndex[offset] == childrenPlans[offset].size()) {
// Reset the index when it reaches the size of the current child plan list
childrenPlanIndex[offset] = 0;
offset++;
if (offset == childrenPlans.size()) {
break;
}
} else {
break;
}
Expand All @@ -195,7 +194,7 @@ private void assembleAllCombinationPlanTree(Plan root, Pattern<Plan> rootPattern

@Override
public boolean hasNext() {
return resultIndex < results.size();
return resultIndex < resultsSize;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@

package org.apache.doris.nereids.trees;

import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.planner.PlanNodeId;

import com.google.common.collect.ImmutableList;

import java.util.List;
Expand All @@ -33,7 +29,6 @@
*/
public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
implements TreeNode<NODE_TYPE> {
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
protected final List<NODE_TYPE> children;
// TODO: Maybe we should use a GroupPlan to avoid TreeNode hold the GroupExpression.
// https://github.com/apache/doris/pull/9807#discussion_r884829067
Expand All @@ -59,12 +54,4 @@ public List<NODE_TYPE> children() {
public int arity() {
return children.size();
}

/**
* used for PhysicalPlanTranslator only
* @return PlanNodeId
*/
public PlanNodeId translatePlanNodeId() {
return id.toPlanNodeId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@
import org.apache.doris.nereids.util.Utils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Abstract class for all Expression in Nereids.
Expand Down Expand Up @@ -247,8 +247,19 @@ public final Set<Slot> getInputSlots() {
return collect(Slot.class::isInstance);
}

/**
* Get all the input slot ids of the expression.
* <p>
* Note that the input slots of a subquery's inner plan are not included.
*/
public final Set<ExprId> getInputSlotExprIds() {
return getInputSlots().stream().map(NamedExpression::getExprId).collect(Collectors.toSet());
ImmutableSet.Builder<ExprId> result = ImmutableSet.builder();
foreach(node -> {
if (node instanceof Slot) {
result.add(((Slot) node).getExprId());
}
});
return result.build();
}

public boolean isLiteral() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.util.MutableState;
import org.apache.doris.nereids.util.MutableState.EmptyMutableState;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.planner.PlanNodeId;
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Supplier;
Expand All @@ -45,6 +47,7 @@
*/
public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Plan {
public static final String FRAGMENT_ID = "fragment";
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();

protected final Statistics statistics;
protected final PlanType type;
Expand Down Expand Up @@ -168,4 +171,12 @@ public Optional<Object> getMutableState(String key) {
public void setMutableState(String key, Object state) {
this.mutableState = this.mutableState.set(key, state);
}

/**
* Converts this plan's {@code id} into a {@link PlanNodeId} by calling {@code id.toPlanNodeId()}.
* <p>
* Used for PhysicalPlanTranslator only.
*
* @return the PlanNodeId derived from this plan's object id
*/
public PlanNodeId translatePlanNodeId() {
return id.toPlanNodeId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
import org.apache.doris.thrift.TRuntimeFilterType;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -113,22 +113,25 @@ private PhysicalHashJoin(
* Return pair of left used slots and right used slots.
*/
public Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
List<ExprId> exprIds1 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
List<ExprId> exprIds2 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
int size = hashJoinConjuncts.size();

List<ExprId> exprIds1 = new ArrayList<>(size);
List<ExprId> exprIds2 = new ArrayList<>(size);

Set<ExprId> leftExprIds = left().getOutputExprIdSet();
Set<ExprId> rightExprIds = right().getOutputExprIdSet();

for (Expression expr : hashJoinConjuncts) {
expr.getInputSlotExprIds().forEach(exprId -> {
for (ExprId exprId : expr.getInputSlotExprIds()) {
if (leftExprIds.contains(exprId)) {
exprIds1.add(exprId);
} else if (rightExprIds.contains(exprId)) {
exprIds2.add(exprId);
} else {
throw new RuntimeException("Could not generate valid equal on clause slot pairs for join");
throw new RuntimeException("Invalid ExprId found: " + exprId
+ ". Cannot generate valid equal on clause slot pairs for join.");
}
});
}
}
return Pair.of(exprIds1, exprIds2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -66,9 +67,10 @@ public static boolean couldBroadcast(Join join) {
* check if the row count and data size of child(1) in the broadcast join are within the session limits.
*/
public static boolean checkBroadcastJoinStats(PhysicalHashJoin<? extends Plan, ? extends Plan> join) {
double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
double memLimit = sessionVariable.getMaxExecMemByte();
double rowsLimit = sessionVariable.getBroadcastRowCountLimit();
double brMemlimit = sessionVariable.getBroadcastHashtableMemLimitPercentage();
double datasize = join.getGroupExpression().get().child(1).getStatistics().computeSize();
double rowCount = join.getGroupExpression().get().child(1).getStatistics().getRowCount();
return rowCount <= rowsLimit && datasize <= memLimit * brMemlimit;
Expand Down Expand Up @@ -114,12 +116,12 @@ boolean isCoveredByRightSlots(ExprId slot) {
* @return true if the equal can be used as hash join condition
*/
public boolean isHashJoinCondition(EqualTo equalTo) {
Set<Slot> equalLeft = equalTo.left().collect(Slot.class::isInstance);
Set<Slot> equalLeft = equalTo.left().getInputSlots();
if (equalLeft.isEmpty()) {
return false;
}

Set<Slot> equalRight = equalTo.right().collect(Slot.class::isInstance);
Set<Slot> equalRight = equalTo.right().getInputSlots();
if (equalRight.isEmpty()) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public static Plan filterOrSelf(Set<Expression> predicates, Plan plan) {
* normalize a comparison predicate on a binary plan so that its two sides correspond to the children's outputs.
*/
public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
Set<Slot> slots = expression.left().collect(Slot.class::isInstance);
Set<Slot> slots = expression.left().getInputSlots();
Set<Slot> leftSlots = left.getOutputSet();
Set<Slot> buffer = Sets.newHashSet(slots);
buffer.removeAll(leftSlots);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@

import java.util.Iterator;

public class GroupExpressionMatchingTest {
class GroupExpressionMatchingTest {

@Test
public void testLeafNode() {
void testLeafNode() {
Pattern pattern = new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION);

Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test")));
Expand All @@ -61,7 +61,7 @@ public void testLeafNode() {
}

@Test
public void testDepth2() {
void testDepth2() {
Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT,
new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION));

Expand Down Expand Up @@ -93,7 +93,7 @@ public void testDepth2() {
}

@Test
public void testDepth2WithGroup() {
void testDepth2WithGroup() {
Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT, Pattern.GROUP);

Plan leaf = new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"));
Expand All @@ -119,7 +119,7 @@ public void testDepth2WithGroup() {
}

@Test
public void testLeafAny() {
void testLeafAny() {
Pattern pattern = Pattern.ANY;

Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test")));
Expand All @@ -135,7 +135,7 @@ public void testLeafAny() {
}

@Test
public void testAnyWithChild() {
void testAnyWithChild() {
Plan root = new LogicalProject(
ImmutableList.of(new SlotReference("name", StringType.INSTANCE, true,
ImmutableList.of("test"))),
Expand All @@ -159,7 +159,7 @@ public void testAnyWithChild() {
}

@Test
public void testInnerLogicalJoinMatch() {
void testInnerLogicalJoinMatch() {
Plan root = new LogicalJoin(JoinType.INNER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
Expand All @@ -181,7 +181,7 @@ public void testInnerLogicalJoinMatch() {
}

@Test
public void testInnerLogicalJoinMismatch() {
void testInnerLogicalJoinMismatch() {
Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
Expand All @@ -198,7 +198,7 @@ public void testInnerLogicalJoinMismatch() {
}

@Test
public void testTopMatchButChildrenNotMatch() {
void testTopMatchButChildrenNotMatch() {
Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))
Expand All @@ -216,12 +216,12 @@ public void testTopMatchButChildrenNotMatch() {
}

@Test
public void testSubTreeMatch() {
void testSubTreeMatch() {
Plan root =
new LogicalFilter(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")),
new LogicalFilter<>(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")),
new UnboundSlot(Lists.newArrayList("b", "id")))),
new LogicalJoin(JoinType.INNER_JOIN,
new LogicalJoin(JoinType.LEFT_OUTER_JOIN,
new LogicalJoin<>(JoinType.INNER_JOIN,
new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN,
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))),
new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("c")))
Expand Down
Loading

0 comments on commit 6010be8

Please sign in to comment.