Skip to content

Commit

Permalink
[enhancement](Nereids): optimize GroupExpressionMatching and GroupMat…
Browse files Browse the repository at this point in the history
…ching
  • Loading branch information
jackwener committed Nov 1, 2023
1 parent eaed0de commit 0a39727
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public ApplyRuleJob(GroupExpression groupExpression, Rule rule, JobContext conte
}

@Override
public void execute() throws AnalysisException {
public final void execute() throws AnalysisException {
if (groupExpression.hasApplied(rule)
|| groupExpression.isUnused()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -55,6 +54,7 @@ public GroupExpressionIterator iterator() {
public static class GroupExpressionIterator implements Iterator<Plan> {
private final List<Plan> results = Lists.newArrayList();
private int resultIndex = 0;
private int resultsSize;

/**
* Constructor.
Expand All @@ -69,21 +69,19 @@ public GroupExpressionIterator(Pattern<Plan> pattern, GroupExpression groupExpre

int childrenGroupArity = groupExpression.arity();
int patternArity = pattern.arity();
if (!(pattern instanceof SubTreePattern)) {
// (logicalFilter(), multi()) match (logicalFilter()),
// but (logicalFilter(), logicalFilter(), multi()) not match (logicalFilter())
boolean extraMulti = patternArity == childrenGroupArity + 1
&& (pattern.hasMultiChild() || pattern.hasMultiGroupChild());
if (patternArity > childrenGroupArity && !extraMulti) {
return;
}
// (logicalFilter(), multi()) match (logicalFilter()),
// but (logicalFilter(), logicalFilter(), multi()) not match (logicalFilter())
boolean extraMulti = patternArity == childrenGroupArity + 1
&& (pattern.hasMultiChild() || pattern.hasMultiGroupChild());
if (patternArity > childrenGroupArity && !extraMulti) {
return;
}

// (multi()) match (logicalFilter(), logicalFilter()),
// but (logicalFilter()) not match (logicalFilter(), logicalFilter())
if (!pattern.isAny() && patternArity < childrenGroupArity
&& !pattern.hasMultiChild() && !pattern.hasMultiGroupChild()) {
return;
}
// (multi()) match (logicalFilter(), logicalFilter()),
// but (logicalFilter()) not match (logicalFilter(), logicalFilter())
if (!pattern.isAny() && patternArity < childrenGroupArity
&& !pattern.hasMultiChild() && !pattern.hasMultiGroupChild()) {
return;
}

// Pattern.GROUP / Pattern.MULTI / Pattern.MULTI_GROUP can not match GroupExpression
Expand All @@ -93,7 +91,7 @@ public GroupExpressionIterator(Pattern<Plan> pattern, GroupExpression groupExpre

// getPlan return the plan with GroupPlan as children
Plan root = groupExpression.getPlan();
if (patternArity == 0 && !(pattern instanceof SubTreePattern)) {
if (patternArity == 0) {
if (pattern.matchPredicates(root)) {
// if no children pattern, we treat all children as GROUP. e.g. Pattern.ANY.
// leaf plan will enter this branch too, e.g. logicalRelation().
Expand All @@ -103,20 +101,16 @@ public GroupExpressionIterator(Pattern<Plan> pattern, GroupExpression groupExpre
// matching children group, one List<Plan> per child
// first dimension is every child group's plan
// second dimension is all matched plan in one group
List<List<Plan>> childrenPlans = Lists.newArrayListWithCapacity(childrenGroupArity);
List<Plan>[] childrenPlans = new List[childrenGroupArity];
for (int i = 0; i < childrenGroupArity; ++i) {
Group childGroup = groupExpression.child(i);
List<Plan> childrenPlan = matchingChildGroup(pattern, childGroup, i);

if (childrenPlan.isEmpty()) {
if (pattern instanceof SubTreePattern) {
childrenPlan = ImmutableList.of(new GroupPlan(childGroup));
} else {
// current pattern is match but children patterns not match
return;
}
// current pattern is match but children patterns not match
return;
}
childrenPlans.add(childrenPlan);
childrenPlans[i] = childrenPlan;
}
assembleAllCombinationPlanTree(root, pattern, groupExpression, childrenPlans);
} else if (patternArity == 1 && (pattern.hasMultiChild() || pattern.hasMultiGroupChild())) {
Expand All @@ -127,25 +121,22 @@ public GroupExpressionIterator(Pattern<Plan> pattern, GroupExpression groupExpre
results.add(root);
}
}
this.resultsSize = results.size();
}

private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
Group childGroup, int childIndex) {
Pattern<? extends Plan> childPattern;
if (parentPattern instanceof SubTreePattern) {
childPattern = parentPattern;
} else {
boolean isLastPattern = childIndex + 1 >= parentPattern.arity();
int patternChildIndex = isLastPattern ? parentPattern.arity() - 1 : childIndex;

childPattern = parentPattern.child(patternChildIndex);
// translate MULTI and MULTI_GROUP to ANY and GROUP
if (isLastPattern) {
if (childPattern.isMulti()) {
childPattern = Pattern.ANY;
} else if (childPattern.isMultiGroup()) {
childPattern = Pattern.GROUP;
}
boolean isLastPattern = childIndex + 1 >= parentPattern.arity();
int patternChildIndex = isLastPattern ? parentPattern.arity() - 1 : childIndex;

childPattern = parentPattern.child(patternChildIndex);
// translate MULTI and MULTI_GROUP to ANY and GROUP
if (isLastPattern) {
if (childPattern.isMulti()) {
childPattern = Pattern.ANY;
} else if (childPattern.isMultiGroup()) {
childPattern = Pattern.GROUP;
}
}

Expand All @@ -154,48 +145,45 @@ private List<Plan> matchingChildGroup(Pattern<? extends Plan> parentPattern,
}

private void assembleAllCombinationPlanTree(Plan root, Pattern<Plan> rootPattern,
GroupExpression groupExpression,
List<List<Plan>> childrenPlans) {
int[] childrenPlanIndex = new int[childrenPlans.size()];
GroupExpression groupExpression, List<Plan>[] childrenPlans) {
int childrenPlansSize = childrenPlans.length;
int[] childrenPlanIndex = new int[childrenPlansSize];
int offset = 0;
LogicalProperties logicalProperties = groupExpression.getOwnerGroup().getLogicalProperties();

// assemble all combination of plan tree by current root plan and children plan
while (offset < childrenPlans.size()) {
ImmutableList.Builder<Plan> childrenBuilder =
ImmutableList.builderWithExpectedSize(childrenPlans.size());
for (int i = 0; i < childrenPlans.size(); i++) {
childrenBuilder.add(childrenPlans.get(i).get(childrenPlanIndex[i]));
Optional<GroupExpression> groupExprOption = Optional.of(groupExpression);
Optional<LogicalProperties> logicalPropOption = Optional.of(logicalProperties);
while (offset < childrenPlansSize) {
ImmutableList.Builder<Plan> childrenBuilder = ImmutableList.builderWithExpectedSize(childrenPlansSize);
for (int i = 0; i < childrenPlansSize; i++) {
childrenBuilder.add(childrenPlans[i].get(childrenPlanIndex[i]));
}
List<Plan> children = childrenBuilder.build();

// assemble children: replace GroupPlan to real plan,
// withChildren will erase groupExpression, so we must
// withGroupExpression too.
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(Optional.of(groupExpression),
Optional.of(logicalProperties), children);
Plan rootWithChildren = root.withGroupExprLogicalPropChildren(groupExprOption,
logicalPropOption, children);
if (rootPattern.matchPredicates(rootWithChildren)) {
results.add(rootWithChildren);
}
offset = 0;
while (true) {
for (offset = 0; offset < childrenPlansSize; offset++) {
childrenPlanIndex[offset]++;
if (childrenPlanIndex[offset] == childrenPlans.get(offset).size()) {
if (childrenPlanIndex[offset] == childrenPlans[offset].size()) {
// Reset the index when it reaches the size of the current child plan list
childrenPlanIndex[offset] = 0;
offset++;
if (offset == childrenPlans.size()) {
break;
}
} else {
break;
break; // Break the loop when the index is within the size of the current child plan list
}
}
}
}

@Override
public boolean hasNext() {
return resultIndex < results.size();
return resultIndex < resultsSize;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@
public class EnforceMissingPropertiesHelper {
private static final EventProducer ENFORCER_TRACER = new EventProducer(EnforcerEvent.class,
EventChannel.getDefaultChannel().addConsumers(new LogConsumer(EnforcerEvent.class, EventChannel.LOG)));
private final JobContext context;
private final GroupExpression groupExpression;
private Cost curTotalCost;

public EnforceMissingPropertiesHelper(JobContext context, GroupExpression groupExpression,
Cost curTotalCost) {
this.context = context;
this.groupExpression = groupExpression;
this.curTotalCost = curTotalCost;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@

package org.apache.doris.nereids.trees;

import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.planner.PlanNodeId;

import com.google.common.collect.ImmutableList;

import java.util.List;
Expand All @@ -33,7 +29,6 @@
*/
public abstract class AbstractTreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>>
implements TreeNode<NODE_TYPE> {
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();
protected final List<NODE_TYPE> children;
// TODO: Maybe we should use a GroupPlan to avoid TreeNode hold the GroupExpression.
// https://github.com/apache/doris/pull/9807#discussion_r884829067
Expand All @@ -59,12 +54,4 @@ public List<NODE_TYPE> children() {
public int arity() {
return children.size();
}

/**
* used for PhysicalPlanTranslator only
* @return PlanNodeId
*/
public PlanNodeId translatePlanNodeId() {
return id.toPlanNodeId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@
import org.apache.doris.nereids.util.Utils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Abstract class for all Expression in Nereids.
Expand Down Expand Up @@ -247,8 +247,19 @@ public final Set<Slot> getInputSlots() {
return collect(Slot.class::isInstance);
}

/**
* Get all the input slot ids of the expression.
* <p>
* Note that the input slots of subquery's inner plan is not included.
*/
public final Set<ExprId> getInputSlotExprIds() {
return getInputSlots().stream().map(NamedExpression::getExprId).collect(Collectors.toSet());
ImmutableSet.Builder<ExprId> result = ImmutableSet.builder();
foreach(node -> {
if (node instanceof Slot) {
result.add(((Slot) node).getExprId());
}
});
return result.build();
}

public boolean isLiteral() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.util.MutableState;
import org.apache.doris.nereids.util.MutableState.EmptyMutableState;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.planner.PlanNodeId;
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Supplier;
Expand All @@ -45,6 +47,7 @@
*/
public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Plan {
public static final String FRAGMENT_ID = "fragment";
protected final ObjectId id = StatementScopeIdGenerator.newObjectId();

protected final Statistics statistics;
protected final PlanType type;
Expand Down Expand Up @@ -168,4 +171,12 @@ public Optional<Object> getMutableState(String key) {
public void setMutableState(String key, Object state) {
this.mutableState = this.mutableState.set(key, state);
}

/**
* used for PhysicalPlanTranslator only
* @return PlanNodeId
*/
public PlanNodeId translatePlanNodeId() {
return id.toPlanNodeId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
import org.apache.doris.thrift.TRuntimeFilterType;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -113,22 +113,25 @@ private PhysicalHashJoin(
* Return pair of left used slots and right used slots.
*/
public Pair<List<ExprId>, List<ExprId>> getHashConjunctsExprIds() {
List<ExprId> exprIds1 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
List<ExprId> exprIds2 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size());
int size = hashJoinConjuncts.size();

List<ExprId> exprIds1 = new ArrayList<>(size);
List<ExprId> exprIds2 = new ArrayList<>(size);

Set<ExprId> leftExprIds = left().getOutputExprIdSet();
Set<ExprId> rightExprIds = right().getOutputExprIdSet();

for (Expression expr : hashJoinConjuncts) {
expr.getInputSlotExprIds().forEach(exprId -> {
for (ExprId exprId : expr.getInputSlotExprIds()) {
if (leftExprIds.contains(exprId)) {
exprIds1.add(exprId);
} else if (rightExprIds.contains(exprId)) {
exprIds2.add(exprId);
} else {
throw new RuntimeException("Could not generate valid equal on clause slot pairs for join");
throw new RuntimeException("Invalid ExprId found: " + exprId
+ ". Cannot generate valid equal on clause slot pairs for join.");
}
});
}
}
return Pair.of(exprIds1, exprIds2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -66,9 +67,10 @@ public static boolean couldBroadcast(Join join) {
* check if the row count of the left child in the broadcast join is less than a threshold value.
*/
public static boolean checkBroadcastJoinStats(PhysicalHashJoin<? extends Plan, ? extends Plan> join) {
double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte();
double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit();
double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage();
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
double memLimit = sessionVariable.getMaxExecMemByte();
double rowsLimit = sessionVariable.getBroadcastRowCountLimit();
double brMemlimit = sessionVariable.getBroadcastHashtableMemLimitPercentage();
double datasize = join.getGroupExpression().get().child(1).getStatistics().computeSize();
double rowCount = join.getGroupExpression().get().child(1).getStatistics().getRowCount();
return rowCount <= rowsLimit && datasize <= memLimit * brMemlimit;
Expand Down Expand Up @@ -114,12 +116,12 @@ boolean isCoveredByRightSlots(ExprId slot) {
* @return true if the equal can be used as hash join condition
*/
public boolean isHashJoinCondition(EqualTo equalTo) {
Set<Slot> equalLeft = equalTo.left().collect(Slot.class::isInstance);
Set<Slot> equalLeft = equalTo.left().getInputSlots();
if (equalLeft.isEmpty()) {
return false;
}

Set<Slot> equalRight = equalTo.right().collect(Slot.class::isInstance);
Set<Slot> equalRight = equalTo.right().getInputSlots();
if (equalRight.isEmpty()) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public static Plan filterOrSelf(Set<Expression> predicates, Plan plan) {
* normalize comparison predicate on a binary plan to its two sides are corresponding to the child's output.
*/
public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
Set<Slot> slots = expression.left().collect(Slot.class::isInstance);
Set<Slot> slots = expression.left().getInputSlots();
Set<Slot> leftSlots = left.getOutputSet();
Set<Slot> buffer = Sets.newHashSet(slots);
buffer.removeAll(leftSlots);
Expand Down
Loading

0 comments on commit 0a39727

Please sign in to comment.