Skip to content

Commit

Permalink
[enhancement](Nereids) support other join framework in DPHyper (apach…
Browse files Browse the repository at this point in the history
…e#21835)

Implement the CD-A algorithm in order to support other (non-inner) join types in DPHyper.
The algorithm is described in the paper "On the Correct and Complete Enumeration of the Core Search Space".
  • Loading branch information
keanji-x authored Jul 21, 2023
1 parent bed940b commit b76d0d8
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ public class StatementContext {
private StatementBase parsedStatement;
private ColumnAliasGenerator columnAliasGenerator;

private int joinCount = 0;
private int maxNAryInnerJoin = 0;

private boolean isDpHyp = false;
private boolean isOtherJoinReorder = false;

Expand Down Expand Up @@ -112,6 +114,16 @@ public int getMaxNAryInnerJoin() {
return maxNAryInnerJoin;
}

/**
 * Records a continuous-join count, keeping only the largest value seen so far.
 * A value that is not greater than the stored one leaves the stored one untouched.
 *
 * @param joinCount candidate count to record
 */
public void setMaxContinuousJoin(int joinCount) {
    // The stored value never decreases: it is the running maximum.
    this.joinCount = Math.max(this.joinCount, joinCount);
}

/**
 * Returns the largest continuous-join count recorded via
 * {@link #setMaxContinuousJoin(int)}, or 0 if none was recorded.
 */
public int getMaxContinuousJoin() {
    return this.joinCount;
}

/**
 * Returns the current value of the {@code isDpHyp} flag of this statement context.
 */
public boolean isDpHyp() {
    return this.isDpHyp;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.doris.nereids.jobs.executor;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob;
Expand Down Expand Up @@ -57,9 +56,10 @@ public void execute() {
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
serializeStatUsed(cascadesContext.getConnectContext());
// DPHyp optimize
StatementContext statementContext = cascadesContext.getStatementContext();
boolean isDpHyp = getSessionVariable().enableDPHypOptimizer || statementContext.getMaxNAryInnerJoin()
> getSessionVariable().getMaxTableCountUseCascadesJoinReorder();
int maxJoinCount = cascadesContext.getMemo().countMaxContinuousJoin();
cascadesContext.getStatementContext().setMaxContinuousJoin(maxJoinCount);
boolean isDpHyp = getSessionVariable().enableDPHypOptimizer
|| maxJoinCount > getSessionVariable().getMaxTableCountUseCascadesJoinReorder();
cascadesContext.getStatementContext().setDpHyp(isDpHyp);
cascadesContext.getStatementContext().setOtherJoinReorder(false);
if (!getSessionVariable().isDisableJoinReorder() && isDpHyp) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Set;

Expand Down Expand Up @@ -66,7 +67,7 @@ public void execute() throws AnalysisException {
}

private Group optimizePlan(Group group) {
if (group.isInnerJoinGroup()) {
if (group.isValidJoinGroup()) {
return optimizeJoin(group);
}
GroupExpression rootExpr = group.getLogicalExpression();
Expand Down Expand Up @@ -111,19 +112,19 @@ private Group optimizeJoin(Group group) {
* @param group root group, should be join type
* @param hyperGraph build hyperGraph
*/
public void buildGraph(Group group, HyperGraph hyperGraph) {
public BitSet buildGraph(Group group, HyperGraph hyperGraph) {
if (group.isProjectGroup()) {
buildGraph(group.getLogicalExpression().child(0), hyperGraph);
BitSet edgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
processProjectPlan(hyperGraph, group);
return;
return edgeMap;
}
if (!group.isInnerJoinGroup()) {
if (!group.isValidJoinGroup()) {
hyperGraph.addNode(optimizePlan(group));
return;
return new BitSet();
}
buildGraph(group.getLogicalExpression().child(0), hyperGraph);
buildGraph(group.getLogicalExpression().child(1), hyperGraph);
hyperGraph.addEdge(group);
BitSet leftEdgeMap = buildGraph(group.getLogicalExpression().child(0), hyperGraph);
BitSet rightEdgeMap = buildGraph(group.getLogicalExpression().child(1), hyperGraph);
return hyperGraph.addEdge(group, leftEdgeMap, rightEdgeMap);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -112,7 +113,7 @@ public boolean addAlias(Alias alias) {
* @param group The group that is the end node in graph
*/
public void addNode(Group group) {
Preconditions.checkArgument(!group.isInnerJoinGroup());
Preconditions.checkArgument(!group.isValidJoinGroup());
for (Slot slot : group.getLogicalExpression().getPlan().getOutput()) {
Preconditions.checkArgument(!slotToNodeMap.containsKey(slot));
slotToNodeMap.put(slot, LongBitmap.newBitmap(nodes.size()));
Expand All @@ -134,10 +135,11 @@ public HashMap<Long, List<NamedExpression>> getComplexProject() {
*
* @param group The join group
*/
public void addEdge(Group group) {
Preconditions.checkArgument(group.isInnerJoinGroup());
public BitSet addEdge(Group group, BitSet leftEdgeMap, BitSet rightEdgeMap) {
Preconditions.checkArgument(group.isValidJoinGroup());
LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin) group.getLogicalExpression().getPlan();
HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();

for (Expression expression : join.getHashJoinConjuncts()) {
Pair<Long, Long> ends = findEnds(expression);
if (!conjuncts.containsKey(ends)) {
Expand All @@ -152,25 +154,61 @@ public void addEdge(Group group) {
}
conjuncts.get(ends).second.add(expression);
}

BitSet edgeMap = new BitSet();
edgeMap.or(leftEdgeMap);
edgeMap.or(rightEdgeMap);

for (Map.Entry<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> entry : conjuncts
.entrySet()) {
LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first,
entry.getValue().second, JoinHint.NONE, join.left(), join.right());
entry.getValue().second, JoinHint.NONE, join.getMarkJoinSlotReference(),
Lists.newArrayList(join.left(), join.right()));
Edge edge = new Edge(singleJoin, edges.size());
Pair<Long, Long> ends = entry.getKey();
edge.setLeft(ends.first);
edge.setOriginalLeft(ends.first);
edge.setRight(ends.second);
edge.setOriginalRight(ends.second);
initEdgeEnds(ends, edge, leftEdgeMap, rightEdgeMap);
for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) {
nodes.get(nodeIndex).attachEdge(edge);
}
edgeMap.set(edge.getIndex());
edges.add(edge);
}

return edgeMap;
// In MySQL, each edge is reversed and store in edges again for reducing the branch miss
// We don't implement this trick now.
}

/**
 * Computes the left and right end node sets of {@code edge} with the CD-A algorithm
 * (see "On the Correct and Complete Enumeration of the Core Search Space").
 *
 * <p>The ends start as the node sets in {@code ends}. Each edge already registered below
 * the left (resp. right) child may then widen the left (resp. right) end, depending on
 * whether its join type passes the associativity checks against this edge's join type.
 * The result is stored as both the original and the current ends of {@code edge}.
 *
 * <p>NOTE(review): for {@code rightEdges} the child edge's join type is still passed as the
 * first argument to {@code isAssoc}/{@code isRAssoc}; confirm this is the intended operand order.
 *
 * @param ends       node bitmaps referenced by the join condition (first = left, second = right)
 * @param edge       the edge whose ends are being initialized
 * @param leftEdges  indexes (into {@code edges}) of the edges collected from the left child
 * @param rightEdges indexes (into {@code edges}) of the edges collected from the right child
 */
private void initEdgeEnds(Pair<Long, Long> ends, Edge edge, BitSet leftEdges, BitSet rightEdges) {
    long leftEnd = ends.first;
    long rightEnd = ends.second;

    // Widen the left end with nodes of left-subtree edges that cannot be reordered with this edge.
    int idx = leftEdges.nextSetBit(0);
    while (idx >= 0) {
        Edge childEdge = edges.get(idx);
        if (!JoinType.isAssoc(childEdge.getJoinType(), edge.getJoinType())) {
            leftEnd = LongBitmap.or(leftEnd, childEdge.getLeft());
        }
        if (!JoinType.isLAssoc(childEdge.getJoinType(), edge.getJoinType())) {
            leftEnd = LongBitmap.or(leftEnd, childEdge.getRight());
        }
        idx = leftEdges.nextSetBit(idx + 1);
    }

    // Same for the right end, mirrored: the right side is added on a failed assoc check,
    // the left side on a failed r-assoc check.
    idx = rightEdges.nextSetBit(0);
    while (idx >= 0) {
        Edge childEdge = edges.get(idx);
        if (!JoinType.isAssoc(childEdge.getJoinType(), edge.getJoinType())) {
            rightEnd = LongBitmap.or(rightEnd, childEdge.getRight());
        }
        if (!JoinType.isRAssoc(childEdge.getJoinType(), edge.getJoinType())) {
            rightEnd = LongBitmap.or(rightEnd, childEdge.getLeft());
        }
        idx = rightEdges.nextSetBit(idx + 1);
    }

    edge.setOriginalLeft(leftEnd);
    edge.setOriginalRight(rightEnd);
    edge.setLeft(leftEnd);
    edge.setRight(rightEnd);
}

private int findRoot(List<Integer> parent, int idx) {
int root = parent.get(idx);
if (root != idx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
* The Receiver is used for cached the plan that has been emitted and build the new plan
Expand Down Expand Up @@ -117,6 +118,9 @@ public boolean emitCsgCmp(long left, long right, List<Edge> edges) {
List<Expression> hashConjuncts = new ArrayList<>();
List<Expression> otherConjuncts = new ArrayList<>();
JoinType joinType = extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
if (joinType == null) {
return true;
}
long fullKey = LongBitmap.newBitmapUnion(left, right);

List<Plan> physicalJoins = proposeAllPhysicalJoins(joinType, leftPlan, rightPlan, hashConjuncts,
Expand Down Expand Up @@ -207,30 +211,37 @@ private List<Plan> proposeAllPhysicalJoins(JoinType joinType, Plan left, Plan ri
// Check whether only NSL can be performed
LogicalProperties joinProperties = new LogicalProperties(
() -> JoinUtils.getJoinOutput(joinType, left, right));
List<Plan> plans = Lists.newArrayList();
if (JoinUtils.shouldNestedLoopJoin(joinType, hashConjuncts)) {
return Lists.newArrayList(
new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
plans.add(new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
Optional.empty(), joinProperties,
left, right),
new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
joinProperties,
right, left));
left, right));
if (joinType.isSwapJoinType()) {
plans.add(new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
joinProperties,
right, left));
}
} else {
return Lists.newArrayList(
new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, Optional.empty(),
joinProperties,
left, right),
new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, JoinHint.NONE,
Optional.empty(),
joinProperties,
right, left));
plans.add(new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, Optional.empty(),
joinProperties,
left, right));
if (joinType.isSwapJoinType()) {
plans.add(new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, JoinHint.NONE,
Optional.empty(),
joinProperties,
right, left));
}
}
return plans;
}

private JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression> hashConjuncts,
private @Nullable JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression> hashConjuncts,
List<Expression> otherConjuncts) {
JoinType joinType = null;
for (Edge edge : edges) {
if (edge.getJoinType() != joinType && joinType != null) {
return null;
}
Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
joinType = edge.getJoinType();
for (Expression expression : edge.getExpressions()) {
Expand Down
14 changes: 4 additions & 10 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
Expand Down Expand Up @@ -374,16 +373,11 @@ public void mergeTo(Group target) {
/**
* This function used to check whether the group is an end node in DPHyp
*/
public boolean isInnerJoinGroup() {
public boolean isValidJoinGroup() {
Plan plan = getLogicalExpression().getPlan();
if (plan instanceof LogicalJoin
&& ((LogicalJoin) plan).getJoinType() == JoinType.INNER_JOIN) {
// Right now, we only support inner join
Preconditions.checkArgument(!((LogicalJoin) plan).getExpressions().isEmpty(),
"inner join must have join conjuncts");
return true;
}
return false;
return plan instanceof LogicalJoin
&& !((LogicalJoin) plan).isMarkJoin()
&& ((LogicalJoin) plan).getExpressions().size() > 0;
}

public boolean isProjectGroup() {
Expand Down
35 changes: 35 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,41 @@ private Plan skipProjectGetChild(Plan plan) {
return plan;
}

/**
 * Returns the second element of the pair that {@link #countGroupJoin(Group)}
 * computes for the memo root, i.e. the maximum continuous join count of the plan.
 */
public int countMaxContinuousJoin() {
    Pair<Integer, Integer> rootCounts = countGroupJoin(root);
    return rootCounts.second;
}

/**
 * Counts continuous join operators in the plan rooted at {@code group}, following the
 * logical expression of every group.
 *
 * <p>Project groups are transparent: they simply forward the result of their first child.
 * A valid join group extends the chains of its children by one; any other group breaks
 * the chain (its continuous count is 0) but still propagates the maximum seen below it.
 *
 * @param group root group of the subtree to inspect
 * @return a pair whose first element is the number of continuous joins ending at
 *         {@code group} and whose second element is the maximum number of continuous
 *         joins found anywhere in the subtree
 */
public Pair<Integer, Integer> countGroupJoin(Group group) {
    GroupExpression logicalExpr = group.getLogicalExpression();
    List<Pair<Integer, Integer>> children = new ArrayList<>();
    for (Group child : logicalExpr.children()) {
        children.add(countGroupJoin(child));
    }

    // A project does not interrupt a join chain, so pass its child's counts through.
    if (group.isProjectGroup()) {
        return children.get(0);
    }

    int maxJoinCount = 0;
    for (Pair<Integer, Integer> child : children) {
        maxJoinCount = Math.max(maxJoinCount, child.second);
    }

    // Only a valid join group continues the chains of its children; everything else resets to 0.
    int continuousJoinCount = 0;
    if (group.isValidJoinGroup()) {
        continuousJoinCount = 1;
        for (Pair<Integer, Integer> child : children) {
            continuousJoinCount += child.first;
        }
    }
    return Pair.of(continuousJoinCount, Math.max(continuousJoinCount, maxJoinCount));
}

/**
* Add plan to Memo.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ public class RuleSet {

public static final List<Rule> DPHYP_REORDER_RULES = ImmutableList.<Rule>builder()
.add(JoinCommute.BUSHY.build())
.addAll(OTHER_REORDER_RULES)
.build();

public List<Rule> getDPHypReorderRules() {
Expand Down
Loading

0 comments on commit b76d0d8

Please sign in to comment.