Skip to content

Commit

Permalink
[BugFix] Fix push down predicate on repeat node check (backport #47484)…
Browse files Browse the repository at this point in the history
… (#50200)

Co-authored-by: Seaven <seaven_7@qq.com>
  • Loading branch information
mergify[bot] and Seaven authored Aug 23, 2024
1 parent 9c9a2f9 commit 5cf6a54
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.base.ColumnRefSet;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.logical.LogicalFilterOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalRepeatOperator;
Expand All @@ -31,13 +31,11 @@
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
import com.starrocks.sql.optimizer.rule.RuleType;
import org.apache.groovy.util.Maps;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static java.util.function.Function.identity;

public class PushDownPredicateRepeatRule extends TransformationRule {
public PushDownPredicateRepeatRule() {
Expand Down Expand Up @@ -89,20 +87,29 @@ public List<OptExpression> transform(OptExpression input, OptimizerContext conte
* it proves that the expression may contains null value, can not push down
*/
private boolean canPushDownPredicate(ScalarOperator predicate, Set<ColumnRefOperator> repeatColumns) {
Map<ColumnRefOperator, ScalarOperator> m =
repeatColumns.stream().map(col -> new ColumnRefOperator(col.getId(), Type.INVALID, "", true))
.collect(Collectors.toMap(identity(), col -> ConstantOperator.createNull(col.getType())));

ScalarOperator nullEval = new ReplaceColumnRefRewriter(m).rewrite(predicate);

ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter();
// The calculation of the null value is in the constant fold
nullEval = scalarRewriter.rewrite(nullEval, ScalarOperatorRewriter.DEFAULT_REWRITE_RULES);
if (nullEval.equals(ConstantOperator.createBoolean(true))) {
return false;
} else if (!(nullEval instanceof ConstantOperator)) {
ColumnRefSet usedRefs = predicate.getUsedColumns();
if (!usedRefs.containsAny(repeatColumns)) {
return false;
}

for (ColumnRefOperator repeatColumn : repeatColumns) {
if (!usedRefs.contains(repeatColumn)) {
continue;
}
// need check repeat column one by one
Map<ColumnRefOperator, ScalarOperator> m =
Maps.of(repeatColumn, ConstantOperator.createNull(repeatColumn.getType()));
ScalarOperator nullEval = new ReplaceColumnRefRewriter(m).rewrite(predicate);

ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter();
// The calculation of the null value is in the constant fold
nullEval = scalarRewriter.rewrite(nullEval, ScalarOperatorRewriter.DEFAULT_REWRITE_RULES);
if (nullEval.equals(ConstantOperator.createBoolean(true))) {
return false;
} else if (!(nullEval instanceof ConstantOperator)) {
return false;
}
}
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ public void testGroupByCube() throws Exception {
Assert.assertTrue(planFragment.contains("REPEAT_NODE"));
}

@Test
public void testGroupByRollup() throws Exception {
String sql = "select * from (select v1, v2, v3, grouping_id(v1, v3), grouping(v2) " +
"from t0 group by rollup(v1, v2, v3)) x where coalesce(v1, v2, v3) = 1;";
String planFragment = getFragmentPlan(sql);
assertNotContains(planFragment, "PREAGGREGATION: ON\n" +
" PREDICATES: coalesce(1: v1, 2: v2, 3: v3) = 1");
}

@Test
public void testPredicateOnRepeatNode() throws Exception {
FeConstants.runningUnitTest = true;
Expand All @@ -45,7 +54,7 @@ public void testPredicateOnRepeatNode() throws Exception {
Assert.assertTrue(plan.contains("1:REPEAT_NODE\n" +
" | repeat: repeat 2 lines [[], [1], [1, 2]]\n" +
" | PREDICATES: 1: v1 IS NOT NULL"));
Assert.assertTrue(plan.contains("0:OlapScanNode\n" +
Assert.assertTrue(plan, plan.contains("0:OlapScanNode\n" +
" TABLE: t0\n" +
" PREAGGREGATION: ON\n" +
" PREDICATES: 1: v1 IS NOT NULL"));
Expand Down

0 comments on commit 5cf6a54

Please sign in to comment.