Skip to content

Commit

Permalink
[BugFix] Fix split scan rule reusing column ids (StarRocks#51674)
Browse files Browse the repository at this point in the history
Signed-off-by: Seaven <seaven_7@qq.com>
  • Loading branch information
Seaven authored Oct 10, 2024
1 parent 930c97b commit e8bb706
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ private String debugString(String headlinePrefix, String detailPrefix, int limit
String childHeadlinePrefix = detailPrefix + "-> ";
String childDetailPrefix = detailPrefix + " ";
for (OptExpression input : inputs) {
sb.append('\n');
sb.append(input.debugString(childHeadlinePrefix, childDetailPrefix, limitLine));
}
return sb.toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ public Builder withOperator(LogicalOlapScanOperator scanOperator) {
builder.hintsReplicaIds = scanOperator.hintsReplicaIds;
builder.prunedPartitionPredicates = scanOperator.prunedPartitionPredicates;
builder.usePkIndex = scanOperator.usePkIndex;
builder.fromSplitOR = scanOperator.fromSplitOR;
builder.vectorSearchOptions = scanOperator.vectorSearchOptions;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ public class FilterSelectivityEvaluator {

public static final double NON_SELECTIVITY = 100;

public static final int IN_CHILDREN_THRESHOLD = 1024;
public static int IN_CHILDREN_THRESHOLD = 1024;

private int unionNumLimit;
private final int unionNumLimit;

private ScalarOperator predicate;
private final ScalarOperator predicate;

private Statistics statistics;
private final Statistics statistics;

private boolean isDecomposePhase;
private final boolean isDecomposePhase;

public FilterSelectivityEvaluator(ScalarOperator predicate, Statistics statistics, boolean isDecomposePhase) {
this.predicate = predicate;
Expand Down Expand Up @@ -90,23 +90,15 @@ private ColumnFilter evaluateScalarOperator(ScalarOperator scalarOperator) {

private List<ColumnFilter> decomposeInPredicate(InPredicateOperator predicate) {
List<ColumnFilter> inFilters = Lists.newArrayList();
Set<ScalarOperator> inSet = predicate.getChildren().stream().skip(1).collect(Collectors.toSet());
List<ScalarOperator> inList = predicate.getChildren().stream().skip(1).distinct().collect(Collectors.toList());
ColumnRefOperator column = (ColumnRefOperator) predicate.getChild(0);
int totalSize = inSet.size();
int totalSize = inList.size();
int numSubsets = (int) Math.ceil((double) totalSize / IN_CHILDREN_THRESHOLD);
List<List<ScalarOperator>> smallInSets = Lists.newArrayList();
for (int i = 0; i < numSubsets; i++) {
smallInSets.add(Lists.newArrayList(column));
}

int currentIndex = 0;
for (ScalarOperator element : inSet) {
int subsetIndex = currentIndex / IN_CHILDREN_THRESHOLD;
smallInSets.get(subsetIndex).add(element);
currentIndex++;
}
for (int i = 0; i < numSubsets; i++) {
InPredicateOperator newInPredicate = new InPredicateOperator(false, smallInSets.get(i));
List<ScalarOperator> s = Lists.newArrayListWithExpectedSize(IN_CHILDREN_THRESHOLD + 2);
s.add(column);
s.addAll(inList.subList(i * IN_CHILDREN_THRESHOLD, Math.min((i + 1) * IN_CHILDREN_THRESHOLD, totalSize)));
InPredicateOperator newInPredicate = new InPredicateOperator(false, s);
inFilters.add(evaluateScalarOperator(newInPredicate));
}

Expand Down Expand Up @@ -315,13 +307,12 @@ private double adjustNDV(double ndv) {

public static class ColumnFilter implements Comparable<ColumnFilter> {

private Double selectRatio;
private final Double selectRatio;

// TODO add index info
private final Optional<ColumnRefOperator> column;

private Optional<ColumnRefOperator> column;

private ScalarOperator filter;
private final ScalarOperator filter;

public ColumnFilter(double selectRatio, ScalarOperator filter) {
this.selectRatio = selectRatio;
Expand All @@ -343,6 +334,10 @@ public ScalarOperator getFilter() {
return filter;
}

public Optional<ColumnRefOperator> getColumn() {
return column;
}

public boolean isUnknownSelectRatio() {
return selectRatio > 1 && selectRatio < NON_SELECTIVITY;
}
Expand Down Expand Up @@ -376,13 +371,7 @@ public int hashCode() {

@Override
public int compareTo(@NotNull ColumnFilter o) {
if (selectRatio < o.selectRatio) {
return -1;
} else if (selectRatio > o.selectRatio) {
return 1;
} else {
return 0;
}
return selectRatio.compareTo(o.selectRatio);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,26 @@

package com.starrocks.sql.optimizer.rule.transformation;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.starrocks.catalog.Column;
import com.starrocks.common.Pair;
import com.starrocks.qe.ConnectContext;
import com.starrocks.qe.SessionVariable;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.Projection;
import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalUnionOperator;
import com.starrocks.sql.optimizer.operator.pattern.Pattern;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.InPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
import com.starrocks.sql.optimizer.rewrite.scalar.FilterSelectivityEvaluator;
import com.starrocks.sql.optimizer.rewrite.scalar.FilterSelectivityEvaluator.ColumnFilter;
import com.starrocks.sql.optimizer.rewrite.scalar.NegateFilterShuttle;
Expand All @@ -38,6 +45,7 @@
import org.apache.logging.log4j.Logger;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.starrocks.sql.optimizer.rewrite.scalar.FilterSelectivityEvaluator.NON_SELECTIVITY;
Expand Down Expand Up @@ -108,7 +116,7 @@ private List<OptExpression> transformImpl(OptExpression input, OptimizerContext
return Lists.newArrayList();
}

List<ColumnFilter> unknownSelectivityFilters = columnFilters.stream().filter(e -> e.isUnknownSelectRatio())
List<ColumnFilter> unknownSelectivityFilters = columnFilters.stream().filter(ColumnFilter::isUnknownSelectRatio)
.collect(Collectors.toList());
List<ColumnFilter> remainingFilters = columnFilters.stream().filter(e -> !e.isUnknownSelectRatio())
.collect(Collectors.toList());
Expand All @@ -127,8 +135,7 @@ private List<OptExpression> transformImpl(OptExpression input, OptimizerContext

List<ScalarOperator> newScanPredicates = rebuildScanPredicate(decomposeFilters, remainingFilters);

return Lists.newArrayList(OptExpression.create(buildUnionAllOperator(scan, newScanPredicates.size()),
buildUnionAllInputs(scan, newScanPredicates)));
return Lists.newArrayList(buildUnion(context.getColumnRefFactory(), scan, newScanPredicates));
}

private Pair<List<ColumnFilter>, List<ColumnFilter>> chooseRewriteColumnFilter(List<ColumnFilter> columnFilters,
Expand Down Expand Up @@ -178,39 +185,93 @@ private List<ScalarOperator> rebuildScanPredicate(List<ColumnFilter> decomposeFi
ScalarOperator remainingPredicate = Utils.compoundAnd(remainingFilters.stream().map(ColumnFilter::getFilter)
.collect(Collectors.toList()));
List<ScalarOperator> scanPredicates = Lists.newArrayList();
NegateFilterShuttle shuttle = NegateFilterShuttle.getInstance();
for (int i = 0; i < decomposeFilters.size(); i++) {
List<ScalarOperator> elements = Lists.newArrayList();
elements.add(decomposeFilters.get(i).getFilter());
List<ColumnFilter> subList = decomposeFilters.subList(0, i);
for (ColumnFilter columnFilter : subList) {
elements.add(shuttle.negateFilter(columnFilter.getFilter()));

boolean isSplitFromIn = true;
for (ColumnFilter decomposeFilter : decomposeFilters) {
isSplitFromIn &= decomposeFilter.getFilter() instanceof InPredicateOperator;
if (!isSplitFromIn) {
break;
}
isSplitFromIn = !((InPredicateOperator) decomposeFilter.getFilter()).isNotIn();
isSplitFromIn &= decomposeFilter.getColumn().isPresent();
}

if (isSplitFromIn && decomposeFilters.stream().map(ColumnFilter::getColumn).distinct().count() == 1) {
for (ColumnFilter decomposeFilter : decomposeFilters) {
List<ScalarOperator> elements = Lists.newArrayList();
elements.add(decomposeFilter.getFilter());
elements.add(remainingPredicate);
scanPredicates.add(Utils.compoundAnd(elements));
}
} else {
NegateFilterShuttle shuttle = NegateFilterShuttle.getInstance();
for (int i = 0; i < decomposeFilters.size(); i++) {
List<ScalarOperator> elements = Lists.newArrayList();
elements.add(decomposeFilters.get(i).getFilter());
List<ColumnFilter> subList = decomposeFilters.subList(0, i);
for (ColumnFilter columnFilter : subList) {
elements.add(shuttle.negateFilter(columnFilter.getFilter()));
}
elements.add(remainingPredicate);
scanPredicates.add(Utils.compoundAnd(elements));
}
elements.add(remainingPredicate);
scanPredicates.add(Utils.compoundAnd(elements));
}
return scanPredicates;
}

private LogicalUnionOperator buildUnionAllOperator(LogicalOlapScanOperator scanOperator, int childNum) {
List<ColumnRefOperator> outputColumns = scanOperator.getOutputColumns();
private OptExpression buildUnion(ColumnRefFactory factory, LogicalOlapScanOperator scan,
List<ScalarOperator> scanPredicates) {
List<ColumnRefOperator> outputColumns = scan.getOutputColumns();
List<List<ColumnRefOperator>> childOutputColumns = Lists.newArrayList();
for (int i = 0; i < childNum; i++) {
childOutputColumns.add(outputColumns);
List<OptExpression> inputs = Lists.newArrayList();
for (ScalarOperator scanPredicate : scanPredicates) {
Pair<OptExpression, List<ColumnRefOperator>> child =
buildUnionInputs(factory, scan, scanPredicate, outputColumns);
inputs.add(child.first);
childOutputColumns.add(child.second);
}
return new LogicalUnionOperator(outputColumns, childOutputColumns, true);

return OptExpression.create(new LogicalUnionOperator(outputColumns, childOutputColumns, true), inputs);
}

private List<OptExpression> buildUnionAllInputs(LogicalOlapScanOperator scanOperator,
List<ScalarOperator> scanPredicates) {
List<OptExpression> inputs = Lists.newArrayList();
for (ScalarOperator scanPredicate : scanPredicates) {
LogicalOlapScanOperator.Builder builder = new LogicalOlapScanOperator.Builder();
LogicalOlapScanOperator scan = builder.withOperator(scanOperator)
.setPredicate(scanPredicate).setFromSplitOR(true).build();
inputs.add(OptExpression.create(scan));
private Pair<OptExpression, List<ColumnRefOperator>> buildUnionInputs(ColumnRefFactory factory,
LogicalOlapScanOperator scan,
ScalarOperator scanPredicate,
List<ColumnRefOperator> outputs) {
Map<ColumnRefOperator, ColumnRefOperator> replaceRefs = Maps.newHashMap();
Map<Column, ColumnRefOperator> columnToRefs = Maps.newHashMap();
Map<ColumnRefOperator, Column> refToColumns = Maps.newHashMap();

scan.getColumnMetaToColRefMap().forEach((meta, ref) -> {
ColumnRefOperator newRef = factory.create(ref.getName(), ref.getType(), ref.isNullable());
columnToRefs.put(meta, newRef);
refToColumns.put(newRef, meta);
replaceRefs.put(ref, newRef);
});

ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(replaceRefs);
LogicalOlapScanOperator.Builder builder = LogicalOlapScanOperator.builder().withOperator(scan)
.setColRefToColumnMetaMap(refToColumns)
.setColumnMetaToColRefMap(columnToRefs)
.setFromSplitOR(true)
.setPredicate(rewriter.rewrite(scanPredicate));

if (scan.getProjection() != null) {
Map<ColumnRefOperator, ScalarOperator> newProjections = Maps.newHashMap();
scan.getProjection().getColumnRefMap().forEach((k, v) -> {
if (replaceRefs.containsKey(k)) {
Preconditions.checkState(k.equals(v));
newProjections.put(replaceRefs.get(k), replaceRefs.get(k));
} else {
ColumnRefOperator newRef = factory.create(k.getName(), k.getType(), k.isNullable());
newProjections.put(newRef, rewriter.rewrite(v));
replaceRefs.put(k, newRef);
}
});
builder.setProjection(new Projection(newProjections));
}
return inputs;
outputs = outputs.stream().map(replaceRefs::get).collect(Collectors.toList());
return Pair.create(OptExpression.create(builder.build()), outputs);
}

private boolean canBenefitFromSplit(double existSelectRatio, double splitMaxSelectRatio) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,15 +700,15 @@ public void testUnionAllWithTopNRuntimeFilter() throws Exception {

// tbl_mock_015
Assert.assertTrue(plan, plan.contains("probe runtime filters:\n" +
" - filter_id = 1, probe_expr = (1: mock_004)"));
" - filter_id = 4, probe_expr = (80: mock_004)"));
Assert.assertTrue(plan, plan.contains("probe runtime filters:\n" +
" - filter_id = 3, probe_expr = (24: mock_004)"));
" - filter_id = 3, probe_expr = (62: mock_004)"));

// table: tbl_mock_001, rollup: tbl_mock_001
Assert.assertTrue(plan, plan.contains("probe runtime filters:\n" +
" - filter_id = 0, probe_expr = (1: mock_004)"));
" - filter_id = 1, probe_expr = (116: mock_004)"));
Assert.assertTrue(plan, plan.contains("probe runtime filters:\n" +
" - filter_id = 4, probe_expr = (24: mock_004)"));
" - filter_id = 0, probe_expr = (98: mock_004)\n"));

}

Expand Down
Loading

0 comments on commit e8bb706

Please sign in to comment.