Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improvement](nereids) Support join derivation when mv rewrite #29609

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[Improvement](nereids) Support join derivation when query rewrite by …
…materialized view
  • Loading branch information
seawinde committed Jan 9, 2024
commit 04834dcad55b4200b33d27d6ec1bb2e266c47e41
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NonNullable;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
Expand All @@ -50,6 +52,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -149,7 +152,7 @@ protected List<Plan> rewrite(Plan queryPlan, CascadesContext cascadesContext) {
queryStructInfo.addPredicates(pulledUpExpressions);
}
SplitPredicate compensatePredicates = predicatesCompensate(queryStructInfo, viewStructInfo,
queryToViewSlotMapping);
queryToViewSlotMapping, comparisonResult);
// Can not compensate, bail out
if (compensatePredicates.isEmpty()) {
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Expand Down Expand Up @@ -315,20 +318,28 @@ protected List<Expression> rewriteExpression(List<? extends Expression> sourceEx

List<Expression> rewrittenExpressions = new ArrayList<>();
for (int index = 0; index < sourceShuttledExpressions.size(); index++) {
Expression expressionToRewrite = sourceShuttledExpressions.get(index);
if (expressionToRewrite instanceof Literal) {
rewrittenExpressions.add(expressionToRewrite);
Expression expressionShuttledToRewrite = sourceShuttledExpressions.get(index);
if (expressionShuttledToRewrite instanceof Literal) {
rewrittenExpressions.add(expressionShuttledToRewrite);
continue;
}
final Set<Object> slotsToRewrite = expressionToRewrite.collectToSet(
expression -> expression instanceof Slot);
Expression replacedExpression = ExpressionUtils.replace(expressionToRewrite,
final Set<Object> slotsToRewrite =
expressionShuttledToRewrite.collectToSet(expression -> expression instanceof Slot);
Expression replacedExpression = ExpressionUtils.replace(expressionShuttledToRewrite,
targetToTargetReplacementMapping);
if (replacedExpression.anyMatch(slotsToRewrite::contains)) {
// if contains any slot to rewrite, which means can not be rewritten by target, bail out
return ImmutableList.of();
}
Expression sourceExpression = sourceExpressionsToWrite.get(index);
if (sourceExpression instanceof NamedExpression
&& replacedExpression.nullable() != sourceExpression.nullable()) {
// if enable join eliminate, query maybe inner join and mv maybe outer join.
// If the slot is at null generate side, the nullable maybe different between query and view
// So need to force to consistent.
replacedExpression = sourceExpression.nullable() ?
new Nullable(replacedExpression) : new NonNullable(replacedExpression);
}
if (sourceExpression instanceof NamedExpression) {
NamedExpression sourceNamedExpression = (NamedExpression) sourceExpression;
replacedExpression = new Alias(sourceNamedExpression.getExprId(), replacedExpression,
Expand Down Expand Up @@ -358,30 +369,92 @@ protected Expression rewriteExpression(Expression sourceExpressionsToWrite, Plan
* For another example as following:
* predicate a = b in mv, and a = b and c = d in query, the compensatory predicate is c = d
*/
protected SplitPredicate predicatesCompensate(StructInfo queryStructInfo, StructInfo viewStructInfo,
SlotMapping queryToViewSlotMapping) {
protected SplitPredicate predicatesCompensate(
StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping queryToViewSlotMapping,
ComparisonResult comparisonResult
) {
// viewEquivalenceClass to query based
SlotMapping viewToQuerySlotMapping = queryToViewSlotMapping.inverse();
final Set<Expression> equalCompensateConjunctions = compensateEquivalence(
queryStructInfo,
viewStructInfo,
viewToQuerySlotMapping,
comparisonResult);
// range compensate
final Set<Expression> rangeCompensatePredicates = compensateRangePredicate(
queryStructInfo,
viewStructInfo,
viewToQuerySlotMapping,
comparisonResult);
// residual compensate
final Set<Expression> residualCompensatePredicates = compensateResidualPredicate(
queryStructInfo,
viewStructInfo,
viewToQuerySlotMapping,
comparisonResult);
// if the join type in query and mv plan is different, we should check and add filter on mv to make
// the mv join type is accord with query
Set<Set<Slot>> viewNoNullableSlot = comparisonResult.getViewNoNullableSlot();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe `requireNoNullableSlot` would be a better name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, it's better

// extract
if (!viewNoNullableSlot.isEmpty()) {
Set<Expression> queryUsedRejectNullSlotsViewBased = ExpressionUtils.extractConjunction(
queryStructInfo.getSplitPredicate().getRangePredicate()).stream()
.map(expression -> {
if (TypeUtils.isNotNull(expression).isPresent()) {
return ImmutableList.of(TypeUtils.isNotNull(expression).get());
} else {
Set<Object> slotRefrenceSet =
expression.collectToSet(expr -> expr instanceof SlotReference);
if (slotRefrenceSet.size() != 1) {
return null;
}
return slotRefrenceSet.iterator().next();
}
})
.filter(Objects::nonNull)
.map(expr -> ExpressionUtils.replace((Expression) expr,
queryToViewSlotMapping.toSlotReferenceMap()))
.collect(Collectors.toSet());

if (viewNoNullableSlot.stream().anyMatch(
set -> Sets.intersection(set, queryUsedRejectNullSlotsViewBased).isEmpty())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be allMatch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An empty SplitPredicate means "invalid", so this should be anyMatch.

return SplitPredicate.empty();
}
Copy link
Contributor

@keanji-x keanji-x Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does an empty SplitPredicate mean?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe an invalid instance is better

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means "invalid"; I will change it to something more semantic.

}
return SplitPredicate.of(ExpressionUtils.and(equalCompensateConjunctions),
rangeCompensatePredicates.isEmpty() ? BooleanLiteral.of(true)
: ExpressionUtils.and(rangeCompensatePredicates),
residualCompensatePredicates.isEmpty() ? BooleanLiteral.of(true)
: ExpressionUtils.and(residualCompensatePredicates));
}

protected Set<Expression> compensateEquivalence(StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping viewToQuerySlotMapping,
ComparisonResult comparisonResult) {
EquivalenceClass queryEquivalenceClass = queryStructInfo.getEquivalenceClass();
EquivalenceClass viewEquivalenceClass = viewStructInfo.getEquivalenceClass();
// viewEquivalenceClass to query based
Map<SlotReference, SlotReference> viewToQuerySlotMapping = queryToViewSlotMapping.inverse()
.toSlotReferenceMap();
EquivalenceClass viewEquivalenceClassQueryBased = viewEquivalenceClass.permute(viewToQuerySlotMapping);
Map<SlotReference, SlotReference> viewToQuerySlotMap = viewToQuerySlotMapping.toSlotReferenceMap();
EquivalenceClass viewEquivalenceClassQueryBased = viewEquivalenceClass.permute(viewToQuerySlotMap);
if (viewEquivalenceClassQueryBased == null) {
return SplitPredicate.empty();
return ImmutableSet.of();
}
final List<Expression> equalCompensateConjunctions = new ArrayList<>();
final Set<Expression> equalCompensateConjunctions = new HashSet<>();
if (queryEquivalenceClass.isEmpty() && viewEquivalenceClass.isEmpty()) {
equalCompensateConjunctions.add(BooleanLiteral.of(true));
}
if (queryEquivalenceClass.isEmpty() && !viewEquivalenceClass.isEmpty()) {
return SplitPredicate.empty();
if (queryEquivalenceClass.isEmpty()
&& !viewEquivalenceClass.isEmpty()) {
return ImmutableSet.of();
}
EquivalenceClassSetMapping queryToViewEquivalenceMapping = EquivalenceClassSetMapping.generate(
queryEquivalenceClass, viewEquivalenceClassQueryBased);
EquivalenceClassSetMapping queryToViewEquivalenceMapping =
EquivalenceClassSetMapping.generate(queryEquivalenceClass, viewEquivalenceClassQueryBased);
// can not map all target equivalence class, can not compensate
if (queryToViewEquivalenceMapping.getEquivalenceClassSetMap().size()
< viewEquivalenceClass.getEquivalenceSetList().size()) {
return SplitPredicate.empty();
return ImmutableSet.of();
}
// do equal compensate
Set<Set<SlotReference>> mappedQueryEquivalenceSet =
Expand Down Expand Up @@ -410,49 +483,57 @@ protected SplitPredicate predicatesCompensate(StructInfo queryStructInfo, Struct
}
}
);
// TODO range predicates and residual predicates compensate, Simplify implementation.
return equalCompensateConjunctions;
}

protected Set<Expression> compensateResidualPredicate(StructInfo queryStructInfo,
StructInfo viewStructInfo,
SlotMapping viewToQuerySlotMapping,
ComparisonResult comparisonResult) {
SplitPredicate querySplitPredicate = queryStructInfo.getSplitPredicate();
SplitPredicate viewSplitPredicate = viewStructInfo.getSplitPredicate();

// range compensate
List<Expression> rangeCompensate = new ArrayList<>();
Expression queryRangePredicate = querySplitPredicate.getRangePredicate();
Expression viewRangePredicate = viewSplitPredicate.getRangePredicate();
Expression viewRangePredicateQueryBased = ExpressionUtils.replace(viewRangePredicate, viewToQuerySlotMapping);

Set<Expression> queryRangeSet = Sets.newHashSet(ExpressionUtils.extractConjunction(queryRangePredicate));
Set<Expression> viewRangeQueryBasedSet = Sets.newHashSet(
ExpressionUtils.extractConjunction(viewRangePredicateQueryBased));
// query range predicate can not contain all view range predicate when view have range predicate, bail out
if (!viewRangePredicateQueryBased.equals(BooleanLiteral.TRUE) && !queryRangeSet.containsAll(
viewRangeQueryBasedSet)) {
return SplitPredicate.empty();
}
queryRangeSet.removeAll(viewRangeQueryBasedSet);
rangeCompensate.addAll(queryRangeSet);

// residual compensate
List<Expression> residualCompensate = new ArrayList<>();
Expression queryResidualPredicate = querySplitPredicate.getResidualPredicate();
Expression viewResidualPredicate = viewSplitPredicate.getResidualPredicate();
Expression viewResidualPredicateQueryBased =
ExpressionUtils.replace(viewResidualPredicate, viewToQuerySlotMapping);
ExpressionUtils.replace(viewResidualPredicate, viewToQuerySlotMapping.toSlotReferenceMap());
Set<Expression> queryResidualSet =
Sets.newHashSet(ExpressionUtils.extractConjunction(queryResidualPredicate));
Set<Expression> viewResidualQueryBasedSet =
Sets.newHashSet(ExpressionUtils.extractConjunction(viewResidualPredicateQueryBased));
// query residual predicate can not contain all view residual predicate when view have residual predicate,
// bail out
if (!viewResidualPredicateQueryBased.equals(BooleanLiteral.TRUE) && !queryResidualSet.containsAll(
viewResidualQueryBasedSet)) {
return SplitPredicate.empty();
if (!viewResidualPredicateQueryBased.equals(BooleanLiteral.TRUE)
&& !queryResidualSet.containsAll(viewResidualQueryBasedSet)) {
return ImmutableSet.of();
}
queryResidualSet.removeAll(viewResidualQueryBasedSet);
residualCompensate.addAll(queryResidualSet);
return queryResidualSet;
}

return SplitPredicate.of(ExpressionUtils.and(equalCompensateConjunctions),
rangeCompensate.isEmpty() ? BooleanLiteral.of(true) : ExpressionUtils.and(rangeCompensate),
residualCompensate.isEmpty() ? BooleanLiteral.of(true) : ExpressionUtils.and(residualCompensate));
/**
 * Returns the query range conjuncts that the view's range predicate does not already contain.
 *
 * <p>The view range predicate is first rewritten with {@code viewToQuerySlotMapping}; the query
 * and the rewritten view range predicates are then split into conjuncts, and the view conjuncts
 * are removed from the query conjuncts.
 *
 * <p>NOTE(review): the bail-out path and the "nothing left to compensate" path both yield an
 * empty set, so the return value alone does not tell a caller which of the two happened.
 *
 * @param queryStructInfo supplies the query side split predicate
 * @param viewStructInfo supplies the view side split predicate
 * @param viewToQuerySlotMapping slot mapping applied to the view range predicate
 * @param comparisonResult not read by this method
 * @return the remaining query range conjuncts; an empty set when the rewritten view range
 *         predicate is not {@code TRUE} and the query conjuncts do not contain all of its conjuncts
 */
protected Set<Expression> compensateRangePredicate(StructInfo queryStructInfo,
        StructInfo viewStructInfo,
        SlotMapping viewToQuerySlotMapping,
        ComparisonResult comparisonResult) {
    // TODO range predicates and residual predicates compensate, Simplify implementation.
    SplitPredicate queryPredicates = queryStructInfo.getSplitPredicate();
    SplitPredicate viewPredicates = viewStructInfo.getSplitPredicate();

    Expression queryRange = queryPredicates.getRangePredicate();
    // Rewrite the view range predicate through the slot mapping so both sides are comparable.
    Expression viewRangeOnQuerySlots = ExpressionUtils.replace(
            viewPredicates.getRangePredicate(), viewToQuerySlotMapping.toSlotReferenceMap());

    Set<Expression> remainingQueryConjuncts =
            Sets.newHashSet(ExpressionUtils.extractConjunction(queryRange));
    Set<Expression> viewConjuncts =
            Sets.newHashSet(ExpressionUtils.extractConjunction(viewRangeOnQuerySlots));

    // The view range is covered when it is trivially TRUE or every one of its conjuncts
    // also appears among the query conjuncts; otherwise bail out.
    boolean viewRangeCovered = viewRangeOnQuerySlots.equals(BooleanLiteral.TRUE)
            || remainingQueryConjuncts.containsAll(viewConjuncts);
    if (!viewRangeCovered) {
        return ImmutableSet.of();
    }
    remainingQueryConjuncts.removeAll(viewConjuncts);
    return remainingQueryConjuncts;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public Void visitComparisonPredicate(ComparisonPredicate comparisonPredicate, Vo
equalPredicates.add(comparisonPredicate);
return null;
} else {
residualPredicates.add(comparisonPredicate);
rangePredicates.add(comparisonPredicate);
}
} else if ((leftArgOnlyContainsColumnRef && rightArg instanceof Literal)
|| (rightArgOnlyContainsColumnRef && leftArg instanceof Literal)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;

import java.util.List;

/**
* change nullable input col to non_nullable col
*/
Expand All @@ -39,4 +43,9 @@ public FunctionSignature customSignature() {
return FunctionSignature.ret(dataType).args(dataType);
}

@Override
public Expression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a message to this check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

return new NonNullable(children.get(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;

import java.util.List;

/**
* change non_nullable input col to nullable col
*/
Expand All @@ -39,4 +43,9 @@ public FunctionSignature customSignature() {
return FunctionSignature.ret(dataType).args(dataType);
}

@Override
public Expression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a message to this check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I will fix it

return new Nullable(children.get(0));
}
}
22 changes: 22 additions & 0 deletions regression-test/data/nereids_rules_p0/mv/join/inner/inner_join.out
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@
6
6

-- !query1_5_before --
6
6

-- !query1_5_after --
6
6

-- !query2_0_before --
4
4
Expand Down Expand Up @@ -239,6 +247,20 @@
6
6

-- !query3_4_before --
1 1
1 1
1 1
1 1
1 1

-- !query3_4_after --
1 1
1 1
1 1
1 1
1 1

-- !query4_0_before --
4
4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ suite("aggregate_with_roll_up") {
sql "SET enable_fallback_to_original_planner=false"
sql "SET enable_materialized_view_rewrite=true"
sql "SET enable_nereids_timeout = false"
// tmp disable to rewrite, will be removed in the future
sql "SET disable_nereids_rules = 'ELIMINATE_OUTER_JOIN'"

sql """
drop table if exists orders
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ suite("aggregate_without_roll_up") {
sql "SET enable_fallback_to_original_planner=false"
sql "SET enable_materialized_view_rewrite=true"
sql "SET enable_nereids_timeout = false"
// tmp disable to rewrite, will be removed in the future
sql "SET disable_nereids_rules = 'ELIMINATE_OUTER_JOIN'"
sql "SET global enable_auto_analyze = false"

sql """
drop table if exists orders
Expand Down Expand Up @@ -173,8 +170,8 @@ suite("aggregate_without_roll_up") {
}
}

// single table
// with filter
// // single table
// // with filter
def mv1_0 = "select o_shippriority, o_comment, " +
"sum(o_totalprice) as sum_total, " +
"max(o_totalprice) as max_total, " +
Expand Down
Loading