Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improvement](Nereids) Support to query rewrite by materialized view when join input has aggregate #30230

Merged
merged 7 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@

package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
Expand Down Expand Up @@ -57,24 +59,50 @@ public StructInfoNode(int index, Plan plan) {
}

private @Nullable List<Set<Expression>> collectExpressions(Plan plan) {
if (plan instanceof LeafPlan) {
return ImmutableList.of();
}
List<Set<Expression>> childExpressions = collectExpressions(plan.child(0));
if (!isValidNodePlan(plan) || childExpressions == null) {
return null;
}
if (plan instanceof LogicalAggregate) {
return ImmutableList.<Set<Expression>>builder()
.add(ImmutableSet.copyOf(plan.getExpressions()))
.add(ImmutableSet.copyOf(((LogicalAggregate<?>) plan).getGroupByExpressions()))
.addAll(childExpressions)
.build();
}
return ImmutableList.<Set<Expression>>builder()
.add(ImmutableSet.copyOf(plan.getExpressions()))
.addAll(childExpressions)
.build();

Pair<Boolean, Builder<Set<Expression>>> collector = Pair.of(true, ImmutableList.builder());
plan.accept(new DefaultPlanVisitor<Void, Pair<Boolean, ImmutableList.Builder<Set<Expression>>>>() {
/**
 * Records the expressions of an aggregate node and then continues the traversal.
 *
 * <p>Two sets are appended to the collector: all expressions of the aggregate, followed by
 * its group-by expressions. Nothing is recorded once the collector has been marked invalid.
 *
 * <p>The group-by expressions are taken from the visited {@code aggregate} itself. Reading them
 * from the enclosing method's {@code plan} parameter through a cast is wrong whenever the
 * aggregate is not the root of the traversed plan: the cast either fails with a
 * {@link ClassCastException} or silently records another node's group-by keys.
 *
 * @param aggregate the aggregate node being visited
 * @param collector pair of (still valid?, builder of collected expression sets)
 * @return always {@code null}; results are accumulated in {@code collector}
 */
@Override
public Void visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate,
        Pair<Boolean, ImmutableList.Builder<Set<Expression>>> collector) {
    if (!collector.key()) {
        return null;
    }
    collector.value().add(ImmutableSet.copyOf(aggregate.getExpressions()));
    collector.value().add(ImmutableSet.copyOf(aggregate.getGroupByExpressions()));
    // Deliberately calls the parent implementation: the overriding visit(Plan, ...) in this
    // anonymous class (with its isValidNodePlan check) is not applied to the aggregate itself.
    return super.visit(aggregate, collector);
}

/**
 * Records the expressions of a filter node and then continues the traversal through the
 * parent visitor implementation. Does nothing once the collector has been marked invalid.
 *
 * @param filter the filter node being visited
 * @param collector pair of (still valid?, builder of collected expression sets)
 * @return always {@code null}; results are accumulated in {@code collector}
 */
@Override
public Void visitLogicalFilter(LogicalFilter<? extends Plan> filter,
        Pair<Boolean, ImmutableList.Builder<Set<Expression>>> collector) {
    if (collector.key()) {
        Set<Expression> filterExpressions = ImmutableSet.copyOf(filter.getExpressions());
        collector.value().add(filterExpressions);
        return super.visit(filter, collector);
    }
    return null;
}

/**
 * Unwraps a group plan by taking the plan of the group's first logical expression and
 * visiting it with this same visitor. Does nothing once the collector has been marked invalid.
 *
 * @param groupPlan the group plan placeholder being visited
 * @param collector pair of (still valid?, builder of collected expression sets)
 * @return the result of visiting the unwrapped plan, or {@code null} if the collector is invalid
 */
@Override
public Void visitGroupPlan(GroupPlan groupPlan,
        Pair<Boolean, ImmutableList.Builder<Set<Expression>>> collector) {
    return collector.key()
            ? groupPlan.getGroup().getLogicalExpressions().get(0).getPlan().accept(this, collector)
            : null;
}

/**
 * Fallback for every plan type without a dedicated visit method in this visitor.
 *
 * <p>A plan accepted by {@code isValidNodePlan} is traversed by the parent implementation;
 * any other plan marks the collection as failed by clearing the flag in {@code context}.
 *
 * @param plan the plan node being visited
 * @param context pair of (still valid?, builder of collected expression sets)
 * @return the parent visitor's result for a valid plan, otherwise {@code null}
 */
@Override
public Void visit(Plan plan, Pair<Boolean, ImmutableList.Builder<Set<Expression>>> context) {
    if (isValidNodePlan(plan)) {
        return super.visit(plan, context);
    }
    context.first = false;
    return null;
}
}, collector);
return collector.key() ? collector.value().build() : null;
}

private boolean isValidNodePlan(Plan plan) {
Expand Down Expand Up @@ -104,7 +132,7 @@ public Set<CatalogRelation> getCatalogRelation() {

private static Plan extractPlan(Plan plan) {
if (plan instanceof GroupPlan) {
//TODO: Note mv can be in logicalExpression, how can we choose it
// TODO: Note mv can be in logicalExpression, how can we choose it
plan = ((GroupPlan) plan).getGroup().getLogicalExpressions().get(0)
.getPlan();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair = splitToTopPlanAndAggregate(viewStructInfo);
if (viewTopPlanAndAggPair == null) {
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("Split view to top plan and agg fail",
Pair.of("Split view to top plan and agg fail, view doesn't not contain aggregate",
String.format("view plan = %s\n", viewStructInfo.getOriginalPlan().treeString())));
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,13 @@ protected List<Plan> doRewrite(StructInfo queryStructInfo, CascadesContext casca
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("Predicate compensate fail",
String.format("query predicates = %s,\n query equivalenceClass = %s, \n"
+ "view predicates = %s,\n query equivalenceClass = %s\n",
+ "view predicates = %s,\n query equivalenceClass = %s\n"
+ "comparisonResult = %s ",
queryStructInfo.getPredicates(),
queryStructInfo.getEquivalenceClass(),
viewStructInfo.getPredicates(),
viewStructInfo.getEquivalenceClass())));
viewStructInfo.getEquivalenceClass(),
comparisonResult)));
continue;
}
Plan rewrittenPlan;
Expand Down Expand Up @@ -467,21 +469,22 @@ protected SplitPredicate predicatesCompensate(
Set<Set<Slot>> requireNoNullableViewSlot = comparisonResult.getViewNoNullableSlot();
// check whether the query uses the null-reject slots that the view comparison requires
if (!requireNoNullableViewSlot.isEmpty()) {
Set<Expression> queryPulledUpPredicates = queryStructInfo.getPredicates().getPulledUpPredicates();
Set<Expression> queryPulledUpPredicates = comparisonResult.getQueryAllPulledUpExpressions().stream()
.flatMap(expr -> ExpressionUtils.extractConjunction(expr).stream())
.collect(Collectors.toSet());
Set<Expression> nullRejectPredicates = ExpressionUtils.inferNotNull(queryPulledUpPredicates,
cascadesContext);
if (nullRejectPredicates.isEmpty() || queryPulledUpPredicates.containsAll(nullRejectPredicates)) {
// query has not null reject predicates, so return
return SplitPredicate.INVALID_INSTANCE;
}
SlotMapping queryToViewMapping = viewToQuerySlotMapping.inverse();
Set<Expression> queryUsedNeedRejectNullSlotsViewBased = nullRejectPredicates.stream()
.map(expression -> TypeUtils.isNotNull(expression).orElse(null))
.filter(Objects::nonNull)
.map(expr -> ExpressionUtils.replace((Expression) expr, queryToViewMapping.toSlotReferenceMap()))
.collect(Collectors.toSet());
if (requireNoNullableViewSlot.stream().anyMatch(
set -> Sets.intersection(set, queryUsedNeedRejectNullSlotsViewBased).isEmpty())) {
// the query's pulled-up predicates should include null-reject predicates, and every required
// non-nullable slot set must contain at least one slot that the query rejects null on
boolean valid = !queryPulledUpPredicates.containsAll(nullRejectPredicates)
&& requireNoNullableViewSlot.stream().noneMatch(
set -> Sets.intersection(set, queryUsedNeedRejectNullSlotsViewBased).isEmpty());
if (!valid) {
return SplitPredicate.INVALID_INSTANCE;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,23 @@ public class ComparisonResult {
private final boolean valid;
private final List<Expression> viewExpressions;
private final List<Expression> queryExpressions;
private final List<Expression> queryAllPulledUpExpressions;
private final Set<Set<Slot>> viewNoNullableSlot;
private final String errorMessage;

/**
 * Creates a comparison result, storing immutable copies of every supplied collection
 * together with the validity flag and the error message.
 */
ComparisonResult(List<Expression> queryExpressions, List<Expression> viewExpressions,
Set<Set<Slot>> viewNoNullableSlot, boolean valid, String message) {
ComparisonResult(List<Expression> queryExpressions, List<Expression> queryAllPulledUpExpressions,
List<Expression> viewExpressions, Set<Set<Slot>> viewNoNullableSlot, boolean valid, String message) {
// Defensive immutable copies: later changes to the caller's collections do not affect this result.
this.viewExpressions = ImmutableList.copyOf(viewExpressions);
this.queryExpressions = ImmutableList.copyOf(queryExpressions);
this.queryAllPulledUpExpressions = ImmutableList.copyOf(queryAllPulledUpExpressions);
this.viewNoNullableSlot = ImmutableSet.copyOf(viewNoNullableSlot);
this.valid = valid;
this.errorMessage = message;
}

/**
 * Creates an invalid comparison result that carries only an error message; all expression
 * and slot collections are empty and the validity flag is {@code false}.
 *
 * @param errorMessage the reason the comparison failed
 * @return a new invalid {@code ComparisonResult}
 */
public static ComparisonResult newInvalidResWithErrorMessage(String errorMessage) {
return new ComparisonResult(ImmutableList.of(), ImmutableList.of(), ImmutableSet.of(), false, errorMessage);
return new ComparisonResult(ImmutableList.of(), ImmutableList.of(), ImmutableList.of(),
ImmutableSet.of(), false, errorMessage);
}

public List<Expression> getViewExpressions() {
Expand All @@ -59,6 +62,10 @@ public List<Expression> getQueryExpressions() {
return queryExpressions;
}

/**
 * Returns the pulled-up query expressions stored in this result.
 *
 * @return the stored list of pulled-up query expressions
 */
public List<Expression> getQueryAllPulledUpExpressions() {
    return this.queryAllPulledUpExpressions;
}

/**
 * Returns the view slot sets stored in this result as "no nullable" slots.
 *
 * @return the stored set of view slot sets
 */
public Set<Set<Slot>> getViewNoNullableSlot() {
    return this.viewNoNullableSlot;
}
Expand All @@ -78,6 +85,7 @@ public static class Builder {
ImmutableList.Builder<Expression> queryBuilder = new ImmutableList.Builder<>();
ImmutableList.Builder<Expression> viewBuilder = new ImmutableList.Builder<>();
ImmutableSet.Builder<Set<Slot>> viewNoNullableSlotBuilder = new ImmutableSet.Builder<>();
ImmutableList.Builder<Expression> queryAllPulledUpExpressionsBuilder = new ImmutableList.Builder<>();
boolean valid = true;

/**
Expand Down Expand Up @@ -108,25 +116,29 @@ public Builder addViewNoNullableSlot(Set<Slot> viewNoNullableSlot) {
return this;
}

/**
 * Appends the given expressions to the builder's list of pulled-up query expressions.
 *
 * @param expressions the expressions to append
 * @return this builder, for call chaining
 */
public Builder addQueryAllPulledUpExpressions(Collection<? extends Expression> expressions) {
    this.queryAllPulledUpExpressionsBuilder.addAll(expressions);
    return this;
}

/**
 * Tells whether this builder has been marked invalid.
 *
 * @return {@code true} when the {@code valid} flag is {@code false}
 */
public boolean isInvalid() {
    return !this.valid;
}

/**
 * Builds the comparison result from everything accumulated in this builder, with an empty
 * error message.
 *
 * @return the assembled {@code ComparisonResult}
 * @throws IllegalArgumentException if this builder is not valid (raised by the precondition check)
 */
public ComparisonResult build() {
// Invalid results are not produced here; the precondition rejects an invalid builder.
Preconditions.checkArgument(valid, "Comparison result must be valid");
return new ComparisonResult(queryBuilder.build(), viewBuilder.build(),
viewNoNullableSlotBuilder.build(), valid, "");
return new ComparisonResult(queryBuilder.build(), queryAllPulledUpExpressionsBuilder.build(),
viewBuilder.build(), viewNoNullableSlotBuilder.build(), valid, "");
}
}

/**
 * Renders this result as a multi-line diagnostic string.
 *
 * @return {@code "INVALID"} when {@code isInvalid()} is true, otherwise a formatted listing
 *         of the stored expressions and slot sets
 */
@Override
public String toString() {
// Invalid results are summarised with a fixed marker instead of their contents.
if (isInvalid()) {
return "INVALID";
}
return String.format("viewExpressions: %s \n "
+ "queryExpressions :%s \n "
+ "viewNoNullableSlot :%s \n",
viewExpressions, queryExpressions, viewNoNullableSlot);
return String.format("valid: %s \n "
+ "viewExpressions: %s \n "
+ "queryExpressions :%s \n "
+ "viewNoNullableSlot :%s \n"
+ "queryAllPulledUpExpressions :%s \n", valid, viewExpressions, queryExpressions,
viewNoNullableSlot, queryAllPulledUpExpressions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ private ComparisonResult buildComparisonRes() {
for (Pair<JoinType, Set<Slot>> inferredCond : inferredViewEdgeWithCond.values()) {
builder.addViewNoNullableSlot(inferredCond.second);
}
builder.addQueryAllPulledUpExpressions(
getQueryFilterEdges().stream()
.filter(this::canPullUp)
.flatMap(filter -> filter.getExpressions().stream()).collect(Collectors.toList()));
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,33 @@ private static boolean collectStructInfoFromGraph(HyperGraph hyperGraph,
}
// Collect relations from hyper graph which in the bottom plan
hyperGraph.getNodes().forEach(node -> {
// plan relation collector and set to map
StructInfoNode structInfoNode = (StructInfoNode) node;
// plan relation collector and set to map
Plan nodePlan = node.getPlan();
List<CatalogRelation> nodeRelations = new ArrayList<>();
nodePlan.accept(RELATION_COLLECTOR, nodeRelations);
relationBuilder.addAll(nodeRelations);
// every node should only have one relation, this is for LogicalCompatibilityContext
relationIdStructInfoNodeMap.put(nodeRelations.get(0).getRelationId(), (StructInfoNode) node);

// record expressions in node
if (structInfoNode.getExpressions() != null) {
structInfoNode.getExpressions().forEach(expression -> {
ExpressionLineageReplacer.ExpressionReplaceContext replaceContext =
new ExpressionLineageReplacer.ExpressionReplaceContext(
Lists.newArrayList(expression),
ImmutableSet.of(),
ImmutableSet.of());
topPlan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
// Replace expressions by expression map
List<Expression> replacedExpressions = replaceContext.getReplacedExpressions();
shuttledHashConjunctsToConjunctsMap.put(replacedExpressions.get(0), expression);
// Record this, will be used in top level expression shuttle later, see the method
// ExpressionLineageReplacer#visitGroupPlan
namedExprIdAndExprMapping.putAll(replaceContext.getExprIdExpressionMap());
});
}
});
// Collect expression from where in hyper graph
hyperGraph.getFilterEdges().forEach(filterEdge -> {
Expand Down Expand Up @@ -436,7 +456,9 @@ public Boolean visit(Plan plan, Set<JoinType> requiredJoinType) {
if (!(plan instanceof Filter)
&& !(plan instanceof Project)
&& !(plan instanceof CatalogRelation)
&& !(plan instanceof Join)) {
&& !(plan instanceof Join)
&& !(plan instanceof LogicalAggregate && !((LogicalAggregate) plan).getSourceRepeat()
.isPresent())) {
return false;
}
if (plan instanceof Join) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,11 @@
-- !query29_1_after --
0 178.10 1.20 8

-- !query30_0_before --
4 4 68 100.0000 36.5000
6 1 0 22.0000 57.2000

-- !query30_0_after --
4 4 68 100.0000 36.5000
6 1 0 22.0000 57.2000

Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,27 @@ d c 17.00 2
b a 39.00 6
d c 17.00 2

-- !query21_0_before --
4 4 68 100.0000 36.5000
6 1 0 22.0000 57.2000

-- !query21_0_after --
4 4 68 100.0000 36.5000
6 1 0 22.0000 57.2000

-- !query21_1_before --
4 4 92 100.0000 27.0000
6 1 0 22.0000 47.7000

-- !query21_1_after --
4 4 92 100.0000 27.0000
6 1 0 22.0000 47.7000

-- !query21_2_before --
4 4 68 100.0000 36.5000
6 1 0 22.0000 57.2000

-- !query21_2_after --
4 4 68 100.0000 36.5000
6 1 0 22.0000 57.2000

Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,14 @@
3 3 2023-12-11
4 3 2023-12-09

-- !query6_1_before --
2023-12-10 2023-12-10 2 4 3
2023-12-10 2023-12-10 2 4 3

-- !query6_1_after --
2023-12-10 2023-12-10 2 4 3
2023-12-10 2023-12-10 2 4 3

-- !query7_0_before --
3 3 2023-12-11

Expand Down
Loading
Loading