diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 5e0fd69cbc9a05..53b6980d0be542 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -73,6 +73,7 @@ import org.apache.doris.nereids.rules.rewrite.InferPredicates; import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct; import org.apache.doris.nereids.rules.rewrite.LeadingJoin; +import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN; import org.apache.doris.nereids.rules.rewrite.MergeFilters; import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion; import org.apache.doris.nereids.rules.rewrite.MergeProjects; @@ -90,9 +91,9 @@ import org.apache.doris.nereids.rules.rewrite.PushProjectThroughUnion; import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject; import org.apache.doris.nereids.rules.rewrite.PushdownLimit; +import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughWindow; import org.apache.doris.nereids.rules.rewrite.ReorderJoin; -import org.apache.doris.nereids.rules.rewrite.ReplaceLimitNode; import org.apache.doris.nereids.rules.rewrite.RewriteCteChildren; import org.apache.doris.nereids.rules.rewrite.SemiJoinCommute; import org.apache.doris.nereids.rules.rewrite.SimplifyAggGroupBy; @@ -275,9 +276,10 @@ public class Rewriter extends AbstractBatchJobExecutor { // we should refactor like AggregateStrategies, e.g. LimitStrategies, // generate one PhysicalLimit if current distribution is gather or two // PhysicalLimits with gather exchange - new ReplaceLimitNode(), + new LimitSortToTopN(), new SplitLimit(), new PushdownLimit(), + new PushdownTopNThroughJoin(), new PushdownTopNThroughWindow(), new CreatePartitionTopNFromWindow() ) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 25d4db8abe6510..3e1582b9d3eb97 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -165,8 +165,6 @@ public enum RuleType { COLUMN_PRUNING(RuleTypeClass.REWRITE), ELIMINATE_SORT(RuleTypeClass.REWRITE), - PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW(RuleTypeClass.REWRITE), - PUSHDOWN_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE), PUSHDOWN_MIN_MAX_THROUGH_JOIN(RuleTypeClass.REWRITE), PUSHDOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE), PUSHDOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE), @@ -248,7 +246,12 @@ public enum RuleType { PUSH_LIMIT_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE), PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE), PUSH_LIMIT_THROUGH_WINDOW(RuleTypeClass.REWRITE), - PUSH_LIMIT_INTO_SORT(RuleTypeClass.REWRITE), + LIMIT_SORT_TO_TOP_N(RuleTypeClass.REWRITE), + // topN push down + PUSH_TOP_N_THROUGH_JOIN(RuleTypeClass.REWRITE), + PUSH_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE), + PUSH_TOP_N_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE), + PUSH_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE), // adjust nullable ADJUST_NULLABLE(RuleTypeClass.REWRITE), ADJUST_CONJUNCTS_RETURN_TYPE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateLimit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateLimit.java index 9cc19e47d8b5c6..8fbfc13934de57 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateLimit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateLimit.java @@ -19,18 +19,35 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.UnaryNode; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Eliminate limit = 0. */ -public class EliminateLimit extends OneRewriteRuleFactory { +public class EliminateLimit implements RewriteRuleFactory { + @Override - public Rule build() { - return logicalLimit() - .when(limit -> limit.getLimit() == 0) - .thenApply(ctx -> new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(), - ctx.root.getOutput())) - .toRule(RuleType.ELIMINATE_LIMIT); + public List buildRules() { + return ImmutableList.of( + logicalLimit() + .when(limit -> limit.getLimit() == 0) + .thenApply(ctx -> new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(), + ctx.root.getOutput())) + .toRule(RuleType.ELIMINATE_LIMIT), + logicalLimit(logicalOneRowRelation()) + .then(limit -> limit.getLimit() > 0 && limit.getOffset() == 0 + ? limit.child() : new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(), + limit.child().getOutput())) + .toRule(RuleType.ELIMINATE_LIMIT_ON_ONE_ROW_RELATION), + logicalLimit(logicalEmptyRelation()) + .then(UnaryNode::child) + .toRule(RuleType.ELIMINATE_LIMIT_ON_EMPTY_RELATION) + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReplaceLimitNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitSortToTopN.java similarity index 71% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReplaceLimitNode.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitSortToTopN.java index dd4479ae3a577c..fa9abca9e58f86 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ReplaceLimitNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitSortToTopN.java @@ -19,10 +19,7 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.UnaryNode; -import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; @@ -35,7 +32,7 @@ /** * rule to eliminate limit node by replace to other nodes. */ -public class ReplaceLimitNode implements RewriteRuleFactory { +public class LimitSortToTopN implements RewriteRuleFactory { @Override public List buildRules() { return ImmutableList.of( @@ -47,8 +44,8 @@ public List buildRules() { limit.getLimit(), limit.getOffset(), sort.child(0)); - }).toRule(RuleType.PUSH_LIMIT_INTO_SORT), - //limit->proj->sort ==> proj->topN + }).toRule(RuleType.LIMIT_SORT_TO_TOP_N), + // limit -> proj -> sort ==> proj -> topN logicalLimit(logicalProject(logicalSort())) .then(limit -> { LogicalProject> project = limit.child(); @@ -58,15 +55,7 @@ public List buildRules() { limit.getOffset(), sort.child(0)); return project.withChildren(Lists.newArrayList(topN)); - }).toRule(RuleType.PUSH_LIMIT_INTO_SORT), - logicalLimit(logicalOneRowRelation()) - .then(limit -> limit.getLimit() > 0 && limit.getOffset() == 0 - ? limit.child() : new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(), - limit.child().getOutput())) - .toRule(RuleType.ELIMINATE_LIMIT_ON_ONE_ROW_RELATION), - logicalLimit(logicalEmptyRelation()) - .then(UnaryNode::child) - .toRule(RuleType.ELIMINATE_LIMIT_ON_EMPTY_RELATION) + }).toRule(RuleType.LIMIT_SORT_TO_TOP_N) ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java index 386f4a01198a85..a0d64b1a6090fa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java @@ -41,27 +41,26 @@ public class PushdownFilterThroughProject implements RewriteRuleFactory { @Override public List buildRules() { return ImmutableList.of( - RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT.build(logicalFilter(logicalProject()) - .whenNot(filter -> filter.child().getProjects().stream().anyMatch( - expr -> expr.anyMatch(WindowExpression.class::isInstance))) - .then(PushdownFilterThroughProject::pushdownFilterThroughProject)), - // filter(project(limit)) will change to filter(limit(project)) by PushdownProjectThroughLimit, - // then we should change filter(limit(project)) to project(filter(limit)) - RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT - .build(logicalFilter(logicalLimit(logicalProject())) - .whenNot(filter -> filter.child().child().getProjects().stream() - .anyMatch(expr -> expr - .anyMatch(WindowExpression.class::isInstance))) - .then(filter -> { - LogicalLimit> limit = filter.child(); - LogicalProject project = limit.child(); + logicalFilter(logicalProject()) + .whenNot(filter -> filter.child().getProjects().stream().anyMatch( + expr -> expr.anyMatch(WindowExpression.class::isInstance))) + .then(PushdownFilterThroughProject::pushdownFilterThroughProject) + .toRule(RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT), + // filter(project(limit)) will change to filter(limit(project)) by PushdownProjectThroughLimit, + // then we should change filter(limit(project)) to project(filter(limit)) + logicalFilter(logicalLimit(logicalProject())) + .whenNot(filter -> filter.child().child().getProjects().stream() + .anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) + .then(filter -> { + LogicalLimit> limit = filter.child(); + LogicalProject project = limit.child(); - return project.withProjectsAndChild(project.getProjects(), - new LogicalFilter<>( - ExpressionUtils.replace(filter.getConjuncts(), - project.getAliasToProducer()), - limit.withChildren(project.child()))); - })) + return project.withProjectsAndChild(project.getProjects(), + new LogicalFilter<>( + ExpressionUtils.replace(filter.getConjuncts(), + project.getAliasToProducer()), + limit.withChildren(project.child()))); + }).toRule(RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownProjectThroughLimit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownProjectThroughLimit.java index 652d0309106943..8c4d2a93c5697b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownProjectThroughLimit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownProjectThroughLimit.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; /** + *
  * Before:
  *          project
  *             │
@@ -42,6 +43,7 @@
  *             │
  *             ▼
  *          plan node
+ * 
*/ public class PushdownProjectThroughLimit extends OneRewriteRuleFactory { @@ -50,9 +52,7 @@ public Rule build() { return logicalProject(logicalLimit()).thenApply(ctx -> { LogicalProject> logicalProject = ctx.root; LogicalLimit logicalLimit = logicalProject.child(); - return new LogicalLimit<>(logicalLimit.getLimit(), logicalLimit.getOffset(), - logicalLimit.getPhase(), logicalProject.withProjectsAndChild(logicalProject.getProjects(), - logicalLimit.child())); + return logicalLimit.withChildren(logicalProject.withChildren(logicalLimit.child())); }).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_LIMIT); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java new file mode 100644 index 00000000000000..8980664a9e1a3f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; +import org.apache.doris.nereids.util.Utils; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Push down TopN through Outer Join into left child ..... + */ +public class PushdownTopNThroughJoin implements RewriteRuleFactory { + + @Override + public List buildRules() { + return ImmutableList.of( + // topN -> join + logicalTopN(logicalJoin()) + // TODO: complex orderby + .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr) + .allMatch(Slot.class::isInstance)) + .then(topN -> { + LogicalJoin join = topN.child(); + Plan newJoin = pushLimitThroughJoin(topN, join); + if (newJoin == null || topN.child().children().equals(newJoin.children())) { + return null; + } + return topN.withChildren(newJoin); + }) + .toRule(RuleType.PUSH_TOP_N_THROUGH_JOIN), + + // topN -> project -> join + logicalTopN(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots)) + // TODO: complex project + .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr) + .allMatch(Slot.class::isInstance)) + .then(topN -> { + LogicalProject> project = topN.child(); + LogicalJoin join = project.child(); + + Plan newJoin = pushLimitThroughJoin(topN, join); + if (newJoin == null || join.children().equals(newJoin.children())) { + return null; + } + return topN.withChildren(project.withChildren(newJoin)); + }).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_JOIN) + ); + } + + private Plan pushLimitThroughJoin(LogicalTopN topN, LogicalJoin join) { + switch (join.getJoinType()) { + case LEFT_OUTER_JOIN: + Set rightOutputSet = join.right().getOutputSet(); + if (topN.getOrderKeys().stream().map(OrderKey::getExpr) + .anyMatch(e -> Utils.isIntersecting(rightOutputSet, e.getInputSlots()))) { + return null; + } + return join.withChildren(topN.withChildren(join.left()), join.right()); + case RIGHT_OUTER_JOIN: + Set leftOutputSet = join.left().getOutputSet(); + if (topN.getOrderKeys().stream().map(OrderKey::getExpr) + .anyMatch(e -> Utils.isIntersecting(leftOutputSet, e.getInputSlots()))) { + return null; + } + return join.withChildren(join.left(), topN.withChildren(join.right())); + case CROSS_JOIN: + List orderbySlots = topN.getOrderKeys().stream().map(OrderKey::getExpr) + .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList()); + if (join.left().getOutputSet().containsAll(orderbySlots)) { + return join.withChildren(topN.withChildren(join.left()), join.right()); + } else if (join.right().getOutputSet().containsAll(orderbySlots)) { + return join.withChildren(join.left(), topN.withChildren(join.right())); + } else { + return null; + } + default: + // don't push limit. + return null; + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughWindow.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughWindow.java index 755b71199cce48..f1547d910898e4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughWindow.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughWindow.java @@ -59,7 +59,7 @@ public List buildRules() { return topn; } return topn.withChildren(newWindow.get()); - }).toRule(RuleType.PUSHDOWN_TOP_N_THROUGH_WINDOW), + }).toRule(RuleType.PUSH_TOP_N_THROUGH_WINDOW), // topn -> projection -> window logicalTopN(logicalProject(logicalWindow())).then(topn -> { @@ -79,7 +79,7 @@ public List buildRules() { return topn; } return topn.withChildren(project.withChildren(newWindow.get())); - }).toRule(RuleType.PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW) + }).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_WINDOW) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index deb6eb983dd714..41c4f423045b83 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -351,10 +351,6 @@ public static List mergeArguments(Object... arguments) { return builder.build(); } - public static boolean isAllLiteral(Expression... children) { - return Arrays.stream(children).allMatch(c -> c instanceof Literal); - } - public static boolean isAllLiteral(List children) { return children.stream().allMatch(c -> c instanceof Literal); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownLimitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownLimitTest.java index f85882791e0b8b..28e1a7fa468619 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownLimitTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownLimitTest.java @@ -65,7 +65,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSupported { private final LogicalOlapScan scanScore = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.score); - private Plan scanStudent = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student); + private final LogicalOlapScan scanStudent = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student); @Override protected void runBeforeAll() throws Exception { @@ -114,7 +114,7 @@ protected void runBeforeAll() throws Exception { } @Test - public void testPushLimitThroughLeftJoin() { + void testPushLimitThroughLeftJoin() { test(JoinType.LEFT_OUTER_JOIN, true, logicalLimit( logicalProject( @@ -136,7 +136,7 @@ public void testPushLimitThroughLeftJoin() { } @Test - public void testPushLimitThroughRightJoin() { + void testPushLimitThroughRightJoin() { // after use RelationUtil to allocate relation id, the id will increase when getNextId() called. test(JoinType.RIGHT_OUTER_JOIN, true, logicalLimit( @@ -159,7 +159,7 @@ public void testPushLimitThroughRightJoin() { } @Test - public void testPushLimitThroughCrossJoin() { + void testPushLimitThroughCrossJoin() { test(JoinType.CROSS_JOIN, true, logicalLimit( logicalProject( @@ -181,7 +181,7 @@ public void testPushLimitThroughCrossJoin() { } @Test - public void testPushLimitThroughInnerJoin() { + void testPushLimitThroughInnerJoin() { test(JoinType.INNER_JOIN, true, logicalLimit( logicalProject( @@ -203,7 +203,7 @@ public void testPushLimitThroughInnerJoin() { } @Test - public void testTranslate() { + void testTranslate() { PlanChecker.from(connectContext).checkPlannerResult("select * from t1 left join t2 on t1.k1=t2.k1 limit 5", planner -> { List fragments = planner.getFragments(); @@ -227,7 +227,7 @@ public void testTranslate() { } @Test - public void testLimitPushSort() { + void testLimitPushSort() { PlanChecker.from(connectContext) .analyze("select k1 from t1 order by k1 limit 1") .rewrite() @@ -235,7 +235,7 @@ public void testLimitPushSort() { } @Test - public void testLimitPushUnion() { + void testLimitPushUnion() { PlanChecker.from(connectContext) .analyze("select k1 from t1 " + "union all select k2 from t2 " @@ -262,7 +262,7 @@ public void testLimitPushUnion() { } @Test - public void testLimitPushWindow() { + void testLimitPushWindow() { ConnectContext context = MemoTestUtils.createConnectContext(); context.getSessionVariable().setEnablePartitionTopN(true); NamedExpression grade = scanScore.getOutput().get(2).toSlot(); @@ -304,7 +304,7 @@ public void testLimitPushWindow() { } @Test - public void testTopNPushWindow() { + void testTopNPushWindow() { ConnectContext context = MemoTestUtils.createConnectContext(); context.getSessionVariable().setEnablePartitionTopN(true); NamedExpression grade = scanScore.getOutput().get(2).toSlot(); @@ -322,7 +322,7 @@ public void testTopNPushWindow() { List orderKey = ImmutableList.of( new OrderKey(windowAlias1.toSlot(), true, true) ); - LogicalSort sort = new LogicalSort<>(orderKey, window); + LogicalSort sort = new LogicalSort<>(orderKey, window); LogicalPlan plan = new LogicalPlanBuilder(sort) .limit(100) @@ -364,8 +364,8 @@ private Plan generatePlan(JoinType joinType, boolean hasProject) { LogicalJoin join = new LogicalJoin<>( joinType, joinConditions, - new LogicalOlapScan(((LogicalOlapScan) scanScore).getRelationId(), PlanConstructor.score), - new LogicalOlapScan(((LogicalOlapScan) scanStudent).getRelationId(), PlanConstructor.student) + new LogicalOlapScan(scanScore.getRelationId(), PlanConstructor.score), + new LogicalOlapScan(scanStudent.getRelationId(), PlanConstructor.student) ); if (hasProject) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoinTest.java new file mode 100644 index 00000000000000..b44ca08a2d1049 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoinTest.java @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.utframe.TestWithFeService; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class PushdownTopNThroughJoinTest extends TestWithFeService implements MemoPatternMatchSupported { + private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + + connectContext.setDatabase("default_cluster:test"); + + createTable("CREATE TABLE `t1` (\n" + + " `k1` int(11) NOT NULL,\n" + + " `k2` int(11) NOT NULL\n" + + ") ENGINE=OLAP\n" + + "COMMENT 'OLAP'\n" + + "DISTRIBUTED BY HASH(`k1`) BUCKETS 3\n" + + "PROPERTIES (\n" + + "\"replication_allocation\" = \"tag.location.default: 1\",\n" + + "\"in_memory\" = \"false\",\n" + + "\"storage_format\" = \"V2\",\n" + + "\"disable_auto_compaction\" = \"false\"\n" + + ");"); + + createTable("CREATE TABLE `t2` (\n" + + " `k1` int(11) NULL,\n" + + " `k2` int(11) NULL\n" + + ") ENGINE=OLAP\n" + + "COMMENT 'OLAP'\n" + + "DISTRIBUTED BY HASH(`k1`) BUCKETS 3\n" + + "PROPERTIES (\n" + + "\"replication_allocation\" = \"tag.location.default: 1\",\n" + + "\"in_memory\" = \"false\",\n" + + "\"storage_format\" = \"V2\",\n" + + "\"disable_auto_compaction\" = \"false\"\n" + + ");"); + } + + @Test + void testJoin() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) + .topN(10, 0, ImmutableList.of(0)) + .build(); + PlanChecker.from(connectContext, plan) + .applyTopDown(new PushdownTopNThroughJoin()) + .matches( + logicalTopN( + logicalJoin( + logicalTopN().when(l -> l.getLimit() == 10 && l.getOffset() == 0), + logicalOlapScan() + ) + ) + ); + } + + @Test + void testJoinSql() { + PlanChecker.from(connectContext) + .analyze("select * from t1 left join t2 on t1.k1 = t2.k1 order by t1.k1 limit 10") + .rewrite() + .matches( + logicalTopN( + logicalProject( + logicalJoin( + logicalTopN().when(l -> l.getLimit() == 10 && l.getOffset() == 0), + logicalOlapScan() + ) + ) + ) + ); + } + + @Test + void badCase() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) + .topN(10, 0, ImmutableList.of(0)) + .build(); + PlanChecker.from(connectContext, plan) + .applyTopDown(new PushdownTopNThroughJoin()) + .matches( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java index c5024ff9315953..65888232fe8105 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java @@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -142,6 +143,14 @@ public LogicalPlanBuilder limit(long limit) { return limit(limit, 0); } + public LogicalPlanBuilder topN(long limit, long offset, List orderKeySlotsIndex) { + List orderKeys = orderKeySlotsIndex.stream() + .map(i -> new OrderKey(this.plan.getOutput().get(i), false, false)) + .collect(Collectors.toList()); + LogicalTopN topNPlan = new LogicalTopN<>(orderKeys, limit, offset, this.plan); + return from(topNPlan); + } + public LogicalPlanBuilder filter(Expression conjunct) { return filter(ImmutableSet.copyOf(ExpressionUtils.extractConjunction(conjunct))); }