Skip to content

Commit

Permalink
[Feature](Nereids) support MarkJoin (apache#16616)
Browse files Browse the repository at this point in the history
# Proposed changes
1.The new optimizer supports the combination of subquery and disjunction.In the way of MarkJoin, it behaves the same as the old optimizer. For design details see:https://emmymiao87.github.io/jekyll/update/2021/07/25/Mark-Join.html.
2.Implicit type conversion is performed when conjects are generated after subquery parsing
3.Convert the unnesting of scalarSubquery in filter from filter+join to join + Conjuncts.
  • Loading branch information
zhengshiJ authored and yagagagaga committed Mar 9, 2023
1 parent 1729b2b commit 33da4a4
Show file tree
Hide file tree
Showing 78 changed files with 1,531 additions and 192 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,10 @@ private static void processOneSubquery(SelectStmt stmt,
+ "expression: "
+ exprWithSubquery.toSql());
}
if (exprWithSubquery instanceof BinaryPredicate && (childrenContainInOrExists(exprWithSubquery))) {
throw new AnalysisException("Not support binaryOperator children at least one is in or exists subquery"
+ exprWithSubquery.toSql());
}

if (exprWithSubquery instanceof ExistsPredicate) {
// Check if we can determine the result of an ExistsPredicate during analysis.
Expand Down Expand Up @@ -542,6 +546,16 @@ private static void processOneSubquery(SelectStmt stmt,
}
}

private static boolean childrenContainInOrExists(Expr expr) {
boolean contain = false;
for (Expr child : expr.getChildren()) {
contain = contain || child instanceof InPredicate || child instanceof ExistsPredicate;
if (contain) {
break;
}
}
return contain;
}

/**
* Replace an ExistsPredicate that contains a subquery with a BoolLiteral if we
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.analysis.StatementBase;
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.rules.analysis.ColumnAliasGenerator;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.qe.ConnectContext;
Expand All @@ -28,7 +29,9 @@
import com.google.common.base.Suppliers;
import com.google.common.collect.Maps;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.annotation.concurrent.GuardedBy;

/**
Expand All @@ -51,6 +54,10 @@ public class StatementContext {

private StatementBase parsedStatement;

private Set<String> columnNames;

private ColumnAliasGenerator columnAliasGenerator;

public StatementContext() {
this.connectContext = ConnectContext.get();
}
Expand Down Expand Up @@ -111,4 +118,22 @@ public synchronized <T> T getOrRegisterCache(String key, Supplier<T> cacheSuppli
}
return supplier.get();
}

public Set<String> getColumnNames() {
return columnNames == null ? new HashSet<>() : columnNames;
}

public void setColumnNames(Set<String> columnNames) {
this.columnNames = columnNames;
}

public ColumnAliasGenerator getColumnAliasGenerator() {
return columnAliasGenerator == null
? columnAliasGenerator = new ColumnAliasGenerator(this)
: columnAliasGenerator;
}

public String generateColumnName() {
return getColumnAliasGenerator().getNextAlias();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
Expand Down Expand Up @@ -193,6 +194,12 @@ public Expr visitSlotReference(SlotReference slotReference, PlanTranslatorContex
return context.findSlotRef(slotReference.getExprId());
}

@Override
public Expr visitMarkJoinReference(MarkJoinSlotReference markJoinSlotReference, PlanTranslatorContext context) {
return markJoinSlotReference.isExistsHasAgg()
? new BoolLiteral(true) : context.findSlotRef(markJoinSlotReference.getExprId());
}

@Override
public Expr visitLiteral(Literal literal, PlanTranslatorContext context) {
return literal.toLegacyLiteral();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
Expand Down Expand Up @@ -956,7 +957,7 @@ public PlanFragment visitPhysicalHashJoin(

HashJoinNode hashJoinNode = new HashJoinNode(context.nextPlanNodeId(), leftPlanRoot,
rightPlanRoot, JoinType.toJoinOperator(joinType), execEqConjuncts, Lists.newArrayList(),
null, null, null);
null, null, null, hashJoin.isMarkJoin());

PlanFragment currentFragment;
if (JoinUtils.shouldColocateJoin(physicalHashJoin)) {
Expand Down Expand Up @@ -1012,13 +1013,15 @@ public PlanFragment visitPhysicalHashJoin(
.forEach(s -> hashOutputSlotReferenceMap.put(s.getExprId(), s));

Map<ExprId, SlotReference> leftChildOutputMap = Maps.newHashMap();
hashJoin.child(0).getOutput().stream()
Stream.concat(hashJoin.child(0).getOutput().stream(), hashJoin.child(0).getNonUserVisibleOutput().stream())
.map(SlotReference.class::cast)
.forEach(s -> leftChildOutputMap.put(s.getExprId(), s));
context.getOutputMarkJoinSlot().stream().forEach(s -> leftChildOutputMap.put(s.getExprId(), s));
Map<ExprId, SlotReference> rightChildOutputMap = Maps.newHashMap();
hashJoin.child(1).getOutput().stream()
Stream.concat(hashJoin.child(1).getOutput().stream(), hashJoin.child(1).getNonUserVisibleOutput().stream())
.map(SlotReference.class::cast)
.forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
context.getOutputMarkJoinSlot().stream().forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
// translate runtime filter
context.getRuntimeTranslator().ifPresent(runtimeFilterTranslator -> runtimeFilterTranslator
.getRuntimeFilterOfHashJoinNode(physicalHashJoin)
Expand All @@ -1040,6 +1043,9 @@ public PlanFragment visitPhysicalHashJoin(
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
leftIntermediateSlotDescriptor.add(sd);
if (sf instanceof MarkJoinSlotReference && hashJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(sf);
}
}
} else if (hashJoin.getOtherJoinConjuncts().isEmpty()
&& (joinType == JoinType.RIGHT_ANTI_JOIN || joinType == JoinType.RIGHT_SEMI_JOIN)) {
Expand Down Expand Up @@ -1076,6 +1082,14 @@ public PlanFragment visitPhysicalHashJoin(
}
}

if (hashJoin.getMarkJoinSlotReference().isPresent()) {
if (hashJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(hashJoin.getMarkJoinSlotReference().get());
context.setOutputMarkJoinSlot(hashJoin.getMarkJoinSlotReference().get());
}
context.createSlotDesc(intermediateDescriptor, hashJoin.getMarkJoinSlotReference().get());
}

// set slots as nullable for outer join
if (joinType == JoinType.LEFT_OUTER_JOIN || joinType == JoinType.FULL_OUTER_JOIN) {
rightIntermediateSlotDescriptor.forEach(sd -> sd.setIsNullable(true));
Expand Down Expand Up @@ -1142,7 +1156,7 @@ public PlanFragment visitPhysicalNestedLoopJoin(

NestedLoopJoinNode nestedLoopJoinNode = new NestedLoopJoinNode(context.nextPlanNodeId(),
leftFragmentPlanRoot, rightFragmentPlanRoot, tupleIds, JoinType.toJoinOperator(joinType),
null, null, null);
null, null, null, nestedLoopJoin.isMarkJoin());
if (nestedLoopJoin.getStats() != null) {
nestedLoopJoinNode.setCardinality((long) nestedLoopJoin.getStats().getRowCount());
}
Expand All @@ -1157,13 +1171,17 @@ public PlanFragment visitPhysicalNestedLoopJoin(
}

Map<ExprId, SlotReference> leftChildOutputMap = Maps.newHashMap();
nestedLoopJoin.child(0).getOutput().stream()
Stream.concat(nestedLoopJoin.child(0).getOutput().stream(),
nestedLoopJoin.child(0).getNonUserVisibleOutput().stream())
.map(SlotReference.class::cast)
.forEach(s -> leftChildOutputMap.put(s.getExprId(), s));
context.getOutputMarkJoinSlot().stream().forEach(s -> leftChildOutputMap.put(s.getExprId(), s));
Map<ExprId, SlotReference> rightChildOutputMap = Maps.newHashMap();
nestedLoopJoin.child(1).getOutput().stream()
Stream.concat(nestedLoopJoin.child(1).getOutput().stream(),
nestedLoopJoin.child(1).getNonUserVisibleOutput().stream())
.map(SlotReference.class::cast)
.forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
context.getOutputMarkJoinSlot().stream().forEach(s -> rightChildOutputMap.put(s.getExprId(), s));
// make intermediate tuple
List<SlotDescriptor> leftIntermediateSlotDescriptor = Lists.newArrayList();
List<SlotDescriptor> rightIntermediateSlotDescriptor = Lists.newArrayList();
Expand Down Expand Up @@ -1198,6 +1216,7 @@ public PlanFragment visitPhysicalNestedLoopJoin(
.map(outputSlotReferenceMap::get)
.filter(Objects::nonNull)
.collect(Collectors.toList());

// TODO: because of the limitation of be, the VNestedLoopJoinNode will output column from both children
// in the intermediate tuple, so fe have to do the same, if be fix the problem, we can change it back.
for (SlotDescriptor leftSlotDescriptor : leftSlotDescriptors) {
Expand All @@ -1207,6 +1226,9 @@ public PlanFragment visitPhysicalNestedLoopJoin(
SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
leftIntermediateSlotDescriptor.add(sd);
if (sf instanceof MarkJoinSlotReference && nestedLoopJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(sf);
}
}
for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) {
if (!rightSlotDescriptor.isMaterialized()) {
Expand All @@ -1215,6 +1237,17 @@ public PlanFragment visitPhysicalNestedLoopJoin(
SlotReference sf = rightChildOutputMap.get(context.findExprId(rightSlotDescriptor.getId()));
SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf);
rightIntermediateSlotDescriptor.add(sd);
if (sf instanceof MarkJoinSlotReference && nestedLoopJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(sf);
}
}

if (nestedLoopJoin.getMarkJoinSlotReference().isPresent()) {
if (nestedLoopJoin.getFilterConjuncts().isEmpty()) {
outputSlotReferences.add(nestedLoopJoin.getMarkJoinSlotReference().get());
context.setOutputMarkJoinSlot(nestedLoopJoin.getMarkJoinSlotReference().get());
}
context.createSlotDesc(intermediateDescriptor, nestedLoopJoin.getMarkJoinSlotReference().get());
}

// set slots as nullable for outer join
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
Expand Down Expand Up @@ -81,6 +82,8 @@ public class PlanTranslatorContext {
private final Map<ExprId, SlotRef> bufferedSlotRefForWindow = Maps.newHashMap();
private TupleDescriptor bufferedTupleForWindow = null;

private List<MarkJoinSlotReference> outputMarkJoinSlot = Lists.newArrayList();

public PlanTranslatorContext(CascadesContext ctx) {
this.translator = new RuntimeFilterTranslator(ctx.getRuntimeFilterContext());
}
Expand Down Expand Up @@ -210,4 +213,12 @@ public TupleDescriptor getTupleDesc(TupleId tupleId) {
public DescriptorTable getDescTable() {
return descTable;
}

public void setOutputMarkJoinSlot(MarkJoinSlotReference markJoinSlotReference) {
outputMarkJoinSlot.add(markJoinSlotReference);
}

public List<MarkJoinSlotReference> getOutputMarkJoinSlot() {
return outputMarkJoinSlot;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ public class NereidsRewriter extends BatchRewriteJob {
new AvgDistinctToSumDivCount(),
new CountDistinctRewrite(),

new NormalizeAggregate(),
new ExtractFilterFromCrossJoin()
),

Expand Down Expand Up @@ -116,6 +115,14 @@ public class NereidsRewriter extends BatchRewriteJob {
)
),

// The rule modification needs to be done after the subquery is unnested,
// because for scalarSubQuery, the connection condition is stored in apply in the analyzer phase,
// but when normalizeAggregate is performed, the members in apply cannot be obtained,
// resulting in inconsistent output results and results in apply
topDown(
new NormalizeAggregate()
),

topDown(
new AdjustAggregateNullableForEmptySet()
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ private double calCost(Edge edge, StatsDeriveResult stats,
join.getJoinType(),
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),
join.left(),
join.right());
Expand All @@ -442,6 +443,7 @@ private double calCost(Edge edge, StatsDeriveResult stats,
join.getHashJoinConjuncts(),
join.getOtherJoinConjuncts(),
join.getHint(),
join.getMarkJoinSlotReference(),
join.getLogicalProperties(),
join.left(),
join.right());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -226,16 +227,19 @@ private List<Plan> proposeAllPhysicalJoins(JoinType joinType, Plan left, Plan ri
() -> JoinUtils.getJoinOutput(joinType, left, right));
if (JoinUtils.shouldNestedLoopJoin(joinType, hashConjuncts)) {
return Lists.newArrayList(
new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts, joinProperties, left,
right),
new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, joinProperties,
new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
Optional.empty(), joinProperties,
left, right),
new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
joinProperties,
right, left));
} else {
return Lists.newArrayList(
new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, joinProperties,
left,
right),
new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts, JoinHint.NONE, Optional.empty(),
joinProperties,
left, right),
new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, JoinHint.NONE,
Optional.empty(),
joinProperties,
right, left));
}
Expand All @@ -258,6 +262,17 @@ private JoinType extractJoinTypeAndConjuncts(List<Edge> edges, List<Expression>
return joinType;
}

private boolean extractIsMarkJoin(List<Edge> edges) {
boolean isMarkJoin = false;
JoinType joinType = null;
for (Edge edge : edges) {
Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType());
isMarkJoin = edge.getJoin().isMarkJoin() || isMarkJoin;
joinType = edge.getJoinType();
}
return isMarkJoin;
}

@Override
public void addGroup(long bitmap, Group group) {
Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 1);
Expand Down Expand Up @@ -322,8 +337,8 @@ private void makeLogicalExpression(Group root) {
} else if (physicalPlan instanceof AbstractPhysicalJoin) {
AbstractPhysicalJoin physicalJoin = (AbstractPhysicalJoin) physicalPlan;
logicalPlan = new LogicalJoin<>(physicalJoin.getJoinType(), physicalJoin.getHashJoinConjuncts(),
physicalJoin.getOtherJoinConjuncts(), JoinHint.NONE, physicalJoin.child(0),
physicalJoin.child(1));
physicalJoin.getOtherJoinConjuncts(), JoinHint.NONE, physicalJoin.getMarkJoinSlotReference(),
physicalJoin.child(0), physicalJoin.child(1));
} else {
throw new RuntimeException("DPhyp can only handle join and project operator");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,7 @@ public LogicalPlan visitFromClause(FromClauseContext ctx) {
ExpressionUtils.EMPTY_CONDITION,
ExpressionUtils.EMPTY_CONDITION,
JoinHint.NONE,
Optional.empty(),
left,
right);
left = withJoinRelations(left, relation);
Expand Down Expand Up @@ -1481,6 +1482,7 @@ private LogicalPlan withJoinRelations(LogicalPlan input, RelationContext ctx) {
condition.map(ExpressionUtils::extractConjunction)
.orElse(ExpressionUtils.EMPTY_CONDITION),
joinHint,
Optional.empty(),
last,
plan(join.relationPrimary()));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public PhysicalPlan visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? ext
Map<NamedExpression, Pair<RelationId, Slot>> aliasTransferMap = ctx.getAliasTransferMap();
join.right().accept(this, context);
join.left().accept(this, context);
if (deniedJoinType.contains(join.getJoinType())) {
if (deniedJoinType.contains(join.getJoinType()) || join.isMarkJoin()) {
// copy to avoid bug when next call of getOutputSet()
Set<Slot> slots = join.getOutputSet();
slots.forEach(aliasTransferMap::remove);
Expand Down
Loading

0 comments on commit 33da4a4

Please sign in to comment.