Skip to content

Commit

Permalink
[BugFix] fix left outer join to inner join bug and string not equal r…
Browse files Browse the repository at this point in the history
…ewrite bug for 2.5 (backport #39331) (#40687)

Signed-off-by: ABingHuang <codekhuang@163.com>
  • Loading branch information
ABingHuang authored Feb 5, 2024
1 parent 4d099bd commit 469275d
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,9 @@ private void ruleRewriteIterative(OptExpression tree, TaskContext rootTaskContex
List<Rule> rules = rootTaskContext.getOptimizerContext().getRuleSet().getRewriteRulesByType(ruleSetType);
context.getTaskScheduler().pushTask(new RewriteTreeTask(rootTaskContext, tree, rules, false));
context.getTaskScheduler().executeTasks(rootTaskContext);
if (ruleSetType.equals(RuleSetType.PUSH_DOWN_PREDICATE)) {
context.reset();
}
}

private void ruleRewriteIterative(OptExpression tree, TaskContext rootTaskContext, Rule rule) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.dump.DumpInfo;
import com.starrocks.sql.optimizer.operator.scalar.IsNullPredicateOperator;
import com.starrocks.sql.optimizer.rule.RuleSet;
import com.starrocks.sql.optimizer.task.SeriallyTaskScheduler;
import com.starrocks.sql.optimizer.task.TaskContext;
Expand All @@ -31,6 +32,10 @@ public class OptimizerContext {
private OptimizerConfig optimizerConfig;
private List<MaterializationContext> candidateMvs;

// Is not null predicate can be derived from inner join or semi join,
// which should be kept to be used to convert outer join into inner join.
private List<IsNullPredicateOperator> pushdownNotNullPredicates = Lists.newArrayList();

@VisibleForTesting
public OptimizerContext(Memo memo, ColumnRefFactory columnRefFactory) {
this.memo = memo;
Expand Down Expand Up @@ -129,4 +134,17 @@ public List<MaterializationContext> getCandidateMvs() {
public void addCandidateMvs(MaterializationContext candidateMv) {
this.candidateMvs.add(candidateMv);
}

public List<IsNullPredicateOperator> getPushdownNotNullPredicates() {
return pushdownNotNullPredicates;
}

public void addPushdownNotNullPredicates(IsNullPredicateOperator notNullPredicate) {
pushdownNotNullPredicates.add(notNullPredicate);
}

// Should clear pushdownNotNullPredicates after each call of PUSH_DOWN_PREDICATE rule set
public void reset() {
pushdownNotNullPredicates.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.starrocks.analysis.JoinOperator;
import com.starrocks.sql.optimizer.JoinHelper;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.base.ColumnRefSet;
Expand Down Expand Up @@ -58,15 +59,19 @@ public class JoinPredicatePushdown {
private List<ScalarOperator> leftPushDown;
private List<ScalarOperator> rightPushDown;

private final OptimizerContext optimizerContext;

public JoinPredicatePushdown(
OptExpression joinOptExpression, boolean isOnPredicate, boolean directToChild,
ColumnRefFactory columnRefFactory) {
ColumnRefFactory columnRefFactory,
OptimizerContext optimizerContext) {
this.joinOptExpression = joinOptExpression;
this.isOnPredicate = isOnPredicate;
this.directToChild = directToChild;
this.columnRefFactory = columnRefFactory;
this.leftPushDown = Lists.newArrayList();
this.rightPushDown = Lists.newArrayList();
this.optimizerContext = optimizerContext;
}

public OptExpression pushdown(ScalarOperator predicateToPush) {
Expand Down Expand Up @@ -323,11 +328,23 @@ private void deriveIsNotNullPredicate(

LogicalJoinOperator joinOp = ((LogicalJoinOperator) join.getOp());
JoinOperator joinType = joinOp.getJoinType();
if ((joinType.isInnerJoin() || joinType.isRightSemiJoin()) && leftPushDown.isEmpty()) {
leftEQ.stream().map(c -> new IsNullPredicateOperator(true, c.clone(), true)).forEach(leftPushDown::add);
boolean isLeftEmpty = leftPushDown.isEmpty();
if (joinType.isInnerJoin() || joinType.isRightSemiJoin()) {
leftEQ.stream().map(c -> new IsNullPredicateOperator(true, c.clone(), true)).forEach(notNull -> {
optimizerContext.addPushdownNotNullPredicates(notNull);
if (isLeftEmpty) {
leftPushDown.add(notNull);
}
});
}
if ((joinType.isInnerJoin() || joinType.isLeftSemiJoin()) && rightPushDown.isEmpty()) {
rightEQ.stream().map(c -> new IsNullPredicateOperator(true, c.clone(), true)).forEach(rightPushDown::add);
boolean isRightEmpty = rightPushDown.isEmpty();
if (joinType.isInnerJoin() || joinType.isLeftSemiJoin()) {
rightEQ.stream().map(c -> new IsNullPredicateOperator(true, c.clone(), true)).forEach(notNull -> {
optimizerContext.addPushdownNotNullPredicates(notNull);
if (isRightEmpty) {
rightPushDown.add(notNull);
}
});
}
joinOp.setHasDeriveIsNotNullPredicate(true);
}
Expand Down Expand Up @@ -512,15 +529,17 @@ private OptExpression convertOuterToInner(OptExpression joinOpt, ScalarOperator
Set<ColumnRefOperator> rightOutputColumnOps = columnRefFactory.getColumnRefs(rightColumns);

if (join.getJoinType().isLeftOuterJoin()) {
if (Utils.canEliminateNull(rightOutputColumnOps, predicateToPush)) {
if (Utils.canEliminateNull(rightOutputColumnOps, predicateToPush)
|| hasPushdownNotNull(rightOutputColumnOps, optimizerContext.getPushdownNotNullPredicates())) {
OptExpression newOpt = OptExpression.create(new LogicalJoinOperator.Builder().withOperator(join)
.setJoinType(JoinOperator.INNER_JOIN)
.build(),
joinOpt.getInputs());
return newOpt;
}
} else if (join.getJoinType().isRightOuterJoin()) {
if (Utils.canEliminateNull(leftOutputColumnOps, predicateToPush)) {
if (Utils.canEliminateNull(leftOutputColumnOps, predicateToPush)
|| hasPushdownNotNull(leftOutputColumnOps, optimizerContext.getPushdownNotNullPredicates())) {
OptExpression newOpt = OptExpression.create(new LogicalJoinOperator.Builder().withOperator(join)
.setJoinType(JoinOperator.INNER_JOIN)
.build(),
Expand Down Expand Up @@ -557,4 +576,8 @@ private OptExpression convertOuterToInner(OptExpression joinOpt, ScalarOperator
}
return joinOpt;
}

private boolean hasPushdownNotNull(Set<ColumnRefOperator> outputColumnOps, List<IsNullPredicateOperator> pushdownNotNulls) {
return pushdownNotNulls.stream().anyMatch(p -> outputColumnOps.containsAll(p.getColumnRefs()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public List<OptExpression> transform(OptExpression input, OptimizerContext conte
LogicalJoinOperator join = (LogicalJoinOperator) input.getOp();
ScalarOperator on = join.getOnPredicate();
JoinPredicatePushdown joinPredicatePushdown = new JoinPredicatePushdown(
input, true, false, context.getColumnRefFactory());
input, true, false, context.getColumnRefFactory(), context);
OptExpression root = joinPredicatePushdown.pushdown(join.getOnPredicate());
((LogicalJoinOperator) root.getOp()).setHasPushDownJoinOnClause(true);
if (root.getOp().equals(input.getOp()) && on.equals(join.getOnPredicate()) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public List<OptExpression> transform(OptExpression input, OptimizerContext conte
LogicalFilterOperator filter = (LogicalFilterOperator) input.getOp();
OptExpression joinOpt = input.getInputs().get(0);
JoinPredicatePushdown joinPredicatePushdown = new JoinPredicatePushdown(
joinOpt, false, false, context.getColumnRefFactory());
joinOpt, false, false, context.getColumnRefFactory(), context);
return Lists.newArrayList(joinPredicatePushdown.pushdown(filter.getPredicate()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1699,8 +1699,9 @@ private OptExpression pushdownPredicatesForJoin(OptExpression optExpression, Sca

private OptExpression doPushdownPredicate(OptExpression joinOptExpression, ScalarOperator predicate) {
Preconditions.checkState(joinOptExpression.getOp() instanceof LogicalJoinOperator);
optimizerContext.reset();
JoinPredicatePushdown joinPredicatePushdown = new JoinPredicatePushdown(joinOptExpression,
false, true, materializationContext.getQueryRefFactory());
false, true, materializationContext.getQueryRefFactory(), optimizerContext);
return joinPredicatePushdown.pushdown(predicate);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.collect.Range;
import com.google.common.collect.TreeRangeSet;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
Expand Down Expand Up @@ -67,14 +68,20 @@ public RangePredicate visitBinaryPredicate(
ScalarOperator right = predicate.getChild(1);
if (left.isColumnRef() && right.isConstantRef()) {
ConstantOperator constant = (ConstantOperator) right;
TreeRangeSet<ConstantOperator> rangeSet = TreeRangeSet.create();
rangeSet.addAll(range(predicate.getBinaryType(), constant));
return new ColumnRangePredicate(left.cast(), rangeSet);
TreeRangeSet<ConstantOperator> rangeSet = range(predicate.getBinaryType(), constant);
if (rangeSet == null) {
residualPredicates.add(predicate);
} else {
return new ColumnRangePredicate(left.cast(), rangeSet);
}
} else if (left.isConstantRef() && right.isColumnRef()) {
ConstantOperator constant = (ConstantOperator) left;
TreeRangeSet<ConstantOperator> rangeSet = TreeRangeSet.create();
rangeSet.addAll(range(predicate.getBinaryType(), constant));
return new ColumnRangePredicate(right.cast(), rangeSet);
TreeRangeSet<ConstantOperator> rangeSet = range(predicate.getBinaryType(), constant);
if (rangeSet == null) {
residualPredicates.add(predicate);
} else {
return new ColumnRangePredicate(right.cast(), rangeSet);
}
} else if (left.isColumnRef() && right.isColumnRef() && context.isAnd()) {
if (predicate.getBinaryType().isEqual()) {
columnEqualityPredicates.add(predicate);
Expand All @@ -84,14 +91,16 @@ public RangePredicate visitBinaryPredicate(
} else if (context.isAnd()) {
if (checkDateTrunc(left, right)) {
ConstantOperator constant = (ConstantOperator) right;
TreeRangeSet<ConstantOperator> rangeSet = TreeRangeSet.create();
rangeSet.addAll(range(predicate.getBinaryType(), constant));
return new ColumnRangePredicate(left.getChild(1).cast(), rangeSet);
TreeRangeSet<ConstantOperator> rangeSet = range(predicate.getBinaryType(), constant);
if (rangeSet != null) {
return new ColumnRangePredicate(left.getChild(1).cast(), rangeSet);
}
} else if (checkDateTrunc(right, left)) {
ConstantOperator constant = (ConstantOperator) left;
TreeRangeSet<ConstantOperator> rangeSet = TreeRangeSet.create();
rangeSet.addAll(range(predicate.getBinaryType(), constant));
return new ColumnRangePredicate(right.getChild(1).cast(), rangeSet);
TreeRangeSet<ConstantOperator> rangeSet = range(predicate.getBinaryType(), constant);
if (rangeSet != null) {
return new ColumnRangePredicate(right.getChild(1).cast(), rangeSet);
}
}
residualPredicates.add(predicate);
}
Expand Down Expand Up @@ -238,8 +247,8 @@ private Optional<RangePredicate> findColumnRangePredicate(
return rangePredicateOptional;
}

private static <C extends Comparable<C>> TreeRangeSet<C> range(BinaryPredicateOperator.BinaryType type, C value) {
TreeRangeSet<C> rangeSet = TreeRangeSet.create();
private static TreeRangeSet<ConstantOperator> range(BinaryPredicateOperator.BinaryType type, ConstantOperator value) {
TreeRangeSet<ConstantOperator> rangeSet = TreeRangeSet.create();
switch (type) {
case EQ:
rangeSet.add(Range.singleton(value));
Expand All @@ -257,6 +266,10 @@ private static <C extends Comparable<C>> TreeRangeSet<C> range(BinaryPredicateO
rangeSet.add(Range.lessThan(value));
return rangeSet;
case NE:
Type valueType = value.getType();
if (!valueType.isNumericType() && !valueType.isDateType()) {
return null;
}
rangeSet.add(Range.greaterThan(value));
rangeSet.add(Range.lessThan(value));
return rangeSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4558,4 +4558,28 @@ public void testStrColumnCountDistinctToBitmap() {
" count(distinct tag_id + 3) from user_tags;");
connectContext.getSessionVariable().setCboCteReuse(true);
}

@Test
public void testMultiLeftOuterJoin() {
{
String mv = "select t1a, t1b, v5, v8 " +
"from test.test_all_type left outer join test.t1 on t1d = v4 " +
"left outer join test.t2 on v5 = v7 where v9 = 10";
String query = "select t1a, t1b, v5, v8 " +
"from test.test_all_type left outer join test.t1 on t1d = v4 " +
"left outer join test.t2 on v5 = v7 where v9 = 10 and t1a != 'xxx'";
MVRewriteChecker checker = testRewriteOK(mv, query);
checker.contains("t1a != 'xxx'");
}

{
String mv = "select v1, v2, v3, v5, v8 " +
"from test.t0 left outer join test.t1 on v1 = v4 " +
"left outer join test.t2 on v5 = v7 where v9 = 10";
String query = "select v1, v2, v3, v5, v8 " +
"from test.t0 left outer join test.t1 on v1 = v4 " +
"left outer join test.t2 on v5 = v7 where v9 = 10 and v3 = 1";
testRewriteOK(mv, query);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ public void testPreprocessMvPartitionMv() throws Exception {
"PROPERTIES('replication_num' = '1');")
.withMaterializedView("create materialized view mv_3\n" +
"distributed by hash(k2) buckets 3\n" +
"refresh async\n" +
"refresh manual\n" +
"as select k2, sum(v1) as total from tbl_with_mv group by k2;")
.withMaterializedView("create materialized view mv_4\n" +
"PARTITION BY k1\n" +
Expand Down
Loading

0 comments on commit 469275d

Please sign in to comment.