Skip to content

Commit

Permalink
EqualPredicate
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener committed Nov 20, 2023
1 parent 80c75b6 commit 248aff9
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
Expand Down Expand Up @@ -1138,7 +1138,7 @@ public PlanFragment visitPhysicalHashJoin(
JoinType joinType = hashJoin.getJoinType();

List<Expr> execEqConjuncts = hashJoin.getHashJoinConjuncts().stream()
.map(EqualTo.class::cast)
.map(EqualPredicate.class::cast)
.map(e -> JoinUtils.swapEqualToForChildrenOrder(e, hashJoin.left().getOutputSet()))
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,10 @@ private void pushDownRuntimeFilterCommon(PhysicalHashJoin<? extends Plan, ? exte
List<TRuntimeFilterType> legalTypes = Arrays.stream(TRuntimeFilterType.values())
.filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0)
.collect(Collectors.toList());
for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) {
List<EqualTo> hashJoinConjuncts = join.getEqualToConjuncts();
for (int i = 0; i < hashJoinConjuncts.size(); i++) {
EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder(
(EqualTo) join.getHashJoinConjuncts().get(i), join.left().getOutputSet()));
hashJoinConjuncts.get(i), join.left().getOutputSet()));
for (TRuntimeFilterType type : legalTypes) {
//bitmap rf is generated by nested loop join.
if (type == TRuntimeFilterType.BITMAP) {
Expand Down Expand Up @@ -487,7 +488,7 @@ private void analyzeRuntimeFilterPushDownIntoCTEInfos(PhysicalHashJoin<? extends
|| !(join.getHashJoinConjuncts().get(0) instanceof EqualTo)) {
break;
} else {
EqualTo equalTo = (EqualTo) join.getHashJoinConjuncts().get(0);
EqualTo equalTo = (EqualTo) join.getEqualToConjuncts().get(0);
equalTos.add(equalTo);
equalCondToJoinMap.put(equalTo, join);
}
Expand Down Expand Up @@ -523,12 +524,11 @@ private void analyzeRuntimeFilterPushDownIntoCTEInfos(PhysicalHashJoin<? extends
// check further whether the join upper side can bring equal set, which
// indicating actually the same runtime filter build side
// see above case 2 for reference
List<Expression> conditions = curJoin.getHashJoinConjuncts();
boolean inSameEqualSet = false;
for (Expression e : conditions) {
for (EqualTo e : curJoin.getEqualToConjuncts()) {
if (e instanceof EqualTo) {
SlotReference oneSide = (SlotReference) ((EqualTo) e).left();
SlotReference anotherSide = (SlotReference) ((EqualTo) e).right();
SlotReference oneSide = (SlotReference) e.left();
SlotReference anotherSide = (SlotReference) e.right();
if (anotherSideSlotSet.contains(oneSide) && anotherSideSlotSet.contains(anotherSide)) {
inSameEqualSet = true;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
Expand Down Expand Up @@ -91,22 +92,18 @@ public Rule build() {
*/
conjunctsChanged |= join.getHashJoinConjuncts().stream()
.map(EqualTo.class::cast)
.map(equalTo ->
(EqualTo) JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet()))
.map(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts)
).anyMatch(Boolean::booleanValue);
.map(equalTo -> JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet()))
.anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts));

JoinUtils.JoinSlotCoverageChecker checker = new JoinUtils.JoinSlotCoverageChecker(
join.left().getOutput(),
join.right().getOutput());
conjunctsChanged |= join.getOtherJoinConjuncts().stream().filter(EqualTo.class::isInstance)
.map(EqualTo.class::cast)
.filter(equalTo -> checker.isHashJoinCondition(equalTo))
.map(equalTo -> (EqualTo) JoinUtils.swapEqualToForChildrenOrder(equalTo,
conjunctsChanged |= join.getOtherJoinConjuncts().stream()
.filter(EqualTo.class::isInstance)
.filter(equalTo -> checker.isHashJoinCondition((EqualPredicate) equalTo))
.map(equalTo -> JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) equalTo,
join.left().getOutputSet()))
.map(equalTo ->
createIsNotNullIfNecessary(equalTo, conjuncts))
.anyMatch(Boolean::booleanValue);
.anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts));
}
if (conjunctsChanged) {
return filter.withConjuncts(conjuncts.stream().collect(ImmutableSet.toImmutableSet()))
Expand Down Expand Up @@ -135,7 +132,7 @@ private JoinType tryEliminateOuterJoin(JoinType joinType, boolean canFilterLeftN
return joinType;
}

private boolean createIsNotNullIfNecessary(EqualTo swapedEqualTo, Collection<Expression> container) {
private boolean createIsNotNullIfNecessary(EqualPredicate swapedEqualTo, Collection<Expression> container) {
boolean containerChanged = false;
if (swapedEqualTo.left().nullable()) {
Not not = new Not(new IsNull(swapedEqualTo.left()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
Expand Down Expand Up @@ -77,11 +77,10 @@ public Rule build() {
Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
join.getHashJoinConjuncts().forEach(conjunct -> {
Preconditions.checkArgument(conjunct instanceof EqualTo);
Preconditions.checkArgument(conjunct instanceof EqualPredicate);
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it
// doesn't swap the two sides.
conjunct = JoinUtils.swapEqualToForChildrenOrder(
(EqualTo) conjunct, join.left().getOutputSet());
conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet());
generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs);
generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
Expand Down Expand Up @@ -306,7 +305,7 @@ public PrefixIndexCheckResult visitInPredicate(InPredicate in, Map<ExprId, Strin

@Override
public PrefixIndexCheckResult visitComparisonPredicate(ComparisonPredicate cp, Map<ExprId, String> context) {
if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
if (cp instanceof EqualPredicate) {
return check(cp, context, PrefixIndexCheckResult::createEqual);
} else {
return check(cp, context, PrefixIndexCheckResult::createNonEqual);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
Expand All @@ -33,7 +34,6 @@
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Like;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
Expand Down Expand Up @@ -210,7 +210,7 @@ private Statistics calculateWhenLiteralRight(ComparisonPredicate cp,
return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
}

if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
if (cp instanceof EqualPredicate) {
return estimateEqualTo(cp, statsForLeft, statsForRight, context);
} else {
if (cp instanceof LessThan || cp instanceof LessThanEqual) {
Expand Down Expand Up @@ -255,7 +255,7 @@ private Statistics calculateWhenBothColumn(ComparisonPredicate cp, EstimationCon
ColumnStatistic statsForLeft, ColumnStatistic statsForRight) {
Expression left = cp.left();
Expression right = cp.right();
if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
if (cp instanceof EqualPredicate) {
return estimateColumnEqualToColumn(left, statsForLeft, right, statsForRight, context);
}
if (cp instanceof GreaterThan || cp instanceof GreaterThanEqual) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
Expand All @@ -45,14 +45,14 @@
public class JoinEstimation {
private static double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3;

private static EqualTo normalizeHashJoinCondition(EqualTo equalTo, Statistics leftStats, Statistics rightStats) {
boolean changeOrder = equalTo.left().getInputSlots().stream().anyMatch(
slot -> rightStats.findColumnStatistics(slot) != null
);
private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, Statistics leftStats,
Statistics rightStats) {
boolean changeOrder = equal.left().getInputSlots().stream()
.anyMatch(slot -> rightStats.findColumnStatistics(slot) != null);
if (changeOrder) {
return new EqualTo(equalTo.right(), equalTo.left());
return equal.commute();
} else {
return equalTo;
return equal;
}
}

Expand Down Expand Up @@ -81,18 +81,18 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ
* In order to avoid error propagation, for unTrustEquations, we only use the biggest selectivity.
*/
List<Double> unTrustEqualRatio = Lists.newArrayList();
List<EqualTo> unTrustableCondition = Lists.newArrayList();
List<EqualPredicate> unTrustableCondition = Lists.newArrayList();
boolean leftBigger = leftStats.getRowCount() > rightStats.getRowCount();
double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());
List<EqualTo> trustableConditions = join.getHashJoinConjuncts().stream()
.map(expression -> (EqualTo) expression)
List<EqualPredicate> trustableConditions = join.getHashJoinConjuncts().stream()
.map(expression -> (EqualPredicate) expression)
.filter(
expression -> {
// since ndv is not accurate, if ndv/rowcount < almostUniqueThreshold,
// this column is regarded as unique.
double almostUniqueThreshold = 0.9;
EqualTo equal = normalizeHashJoinCondition(expression, leftStats, rightStats);
EqualPredicate equal = normalizeHashJoinCondition(expression, leftStats, rightStats);
ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats);
ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats);
boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold
Expand Down Expand Up @@ -204,7 +204,7 @@ private static double estimateJoinConditionSel(Statistics crossJoinStats, Expres
}

private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStats,
Statistics rightStats, Join join, EqualTo equalTo) {
Statistics rightStats, Join join, EqualPredicate equalTo) {
Expression eqLeft = equalTo.left();
Expression eqRight = equalTo.right();
ColumnStatistic probColStats = leftStats.findColumnStatistics(eqLeft);
Expand Down Expand Up @@ -261,7 +261,7 @@ private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics ri
double rowCount = Double.POSITIVE_INFINITY;
for (Expression conjunct : join.getHashJoinConjuncts()) {
double eqRowCount = estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats,
join, (EqualTo) conjunct);
join, (EqualPredicate) conjunct);
if (rowCount > eqRowCount) {
rowCount = eqRowCount;
}
Expand Down Expand Up @@ -336,7 +336,7 @@ public static Statistics estimate(Statistics leftStats, Statistics rightStats, J
private static Statistics updateJoinResultStatsByHashJoinCondition(Statistics innerStats, Join join) {
Map<Expression, ColumnStatistic> updatedCols = new HashMap<>();
for (Expression expr : join.getHashJoinConjuncts()) {
EqualTo equalTo = (EqualTo) expr;
EqualPredicate equalTo = (EqualPredicate) expr;
ColumnStatistic leftColStats = ExpressionEstimation.estimate(equalTo.left(), innerStats);
ColumnStatistic rightColStats = ExpressionEstimation.estimate(equalTo.right(), innerStats);
double minNdv = Math.min(leftColStats.ndv, rightColStats.ndv);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions;

import java.util.List;

/**
* EqualPredicate
*/
public abstract class EqualPredicate extends ComparisonPredicate {

protected EqualPredicate(List<Expression> children, String symbol) {
super(children, symbol);
}

@Override
public EqualPredicate commute() {
return null;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
/**
* Equal to expression: a = b.
*/
public class EqualTo extends ComparisonPredicate implements PropagateNullable {
public class EqualTo extends EqualPredicate implements PropagateNullable {

public EqualTo(Expression left, Expression right) {
super(ImmutableList.of(left, right), "=");
Expand All @@ -55,7 +55,7 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
}

@Override
public ComparisonPredicate commute() {
public EqualTo commute() {
return new EqualTo(right(), left());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,7 @@
* Null safe equal expression: a <=> b.
* Unlike normal equal to expression, null <=> null is true.
*/
public class NullSafeEqual extends ComparisonPredicate implements AlwaysNotNullable {
/**
* Constructor of Null Safe Equal ComparisonPredicate.
*
* @param left left child of Null Safe Equal
* @param right right child of Null Safe Equal
*/
public class NullSafeEqual extends EqualPredicate implements AlwaysNotNullable {
public NullSafeEqual(Expression left, Expression right) {
super(ImmutableList.of(left, right), "<=>");
}
Expand All @@ -61,8 +55,7 @@ public NullSafeEqual withChildren(List<Expression> children) {
}

@Override
public ComparisonPredicate commute() {
public NullSafeEqual commute() {
return new NullSafeEqual(right(), left());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
Expand Down Expand Up @@ -114,6 +115,11 @@ public List<Expression> getHashJoinConjuncts() {
return hashJoinConjuncts;
}

public List<EqualTo> getEqualToConjuncts() {
return hashJoinConjuncts.stream().filter(EqualTo.class::isInstance).map(EqualTo.class::cast)
.collect(Collectors.toList());
}

public boolean isShouldTranslateOutput() {
return shouldTranslateOutput;
}
Expand Down
Loading

0 comments on commit 248aff9

Please sign in to comment.