Skip to content

Commit

Permalink
[feature](Nereids): add ExtractFilterFromJoin rule to support more (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener authored Jan 5, 2023
1 parent 5460c87 commit d36b937
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@
import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateOrderByConstant;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateUnnecessaryProject;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractFilterFromCrossJoin;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction;
import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.logical.InferPredicates;
import org.apache.doris.nereids.rules.rewrite.logical.InnerToCrossJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.logical.MergeSetOperations;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
Expand Down Expand Up @@ -86,6 +88,8 @@ public NereidsRewriteJobExecutor(CascadesContext cascadesContext) {
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
.add(visitorJob(RuleType.INFER_PREDICATES, new InferPredicates()))
.add(topDownBatch(ImmutableList.of(new ExtractFilterFromCrossJoin())))
.add(topDownBatch(ImmutableList.of(new MergeFilters())))
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ public enum RuleType {
REWRITE_SORT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_HAVING_EXPRESSSION(RuleTypeClass.REWRITE),
REWRITE_REPEAT_EXPRESSSION(RuleTypeClass.REWRITE),
EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),
// Merge Consecutive plan
MERGE_PROJECTS(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite.logical;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;

import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Extract All condition From CrossJoin.
*/
public class ExtractFilterFromCrossJoin extends OneRewriteRuleFactory {
@Override
public Rule build() {
return crossLogicalJoin()
.then(join -> {
LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, join.getHint(),
join.left(), join.right());
Set<Expression> predicates = Stream.concat(join.getHashJoinConjuncts().stream(),
join.getOtherJoinConjuncts().stream())
.collect(Collectors.toSet());
return PlanUtils.filterOrSelf(predicates, newJoin);
}).toRule(RuleType.EXTRACT_FILTER_FROM_JOIN);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite.logical;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;

import org.junit.jupiter.api.Test;

class ExtractFilterFromCrossJoinTest implements PatternMatchSupported {
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

@Test
void testExtract() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.hashJoinUsing(scan2, JoinType.CROSS_JOIN, Pair.of(0, 0))
.build();

PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new ExtractFilterFromCrossJoin())
.matches(
logicalFilter(
logicalJoin()
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ void testMultiJoinEliminateCross() {

@Test
@Disabled
// TODO: MultiJoin And EliminateOuter
void testEliminateBelowOuter() {
// FIXME: MultiJoin And EliminateOuter
String sql = "SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id";
PlanChecker.from(connectContext)
.analyze(sql)
Expand Down Expand Up @@ -120,4 +120,30 @@ void testOuterJoin() {
)
);
}

@Test
@Disabled
void testNoFilter() {
String sql = "Select * FROM T1 INNER JOIN T2 On true";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(
crossLogicalJoin()
);
}

@Test
void test() {
String sql = "select T1.score, T2.score from T1 inner join T2 on T1.id = T2.id where T1.score - 2 > T2.score";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(
logicalProject(
innerLogicalJoin()
)
);

}
}

0 comments on commit d36b937

Please sign in to comment.