Skip to content

Commit

Permalink
Fix NPE bug in materialize rewrite (StarRocks#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
Seaven authored Sep 29, 2021
1 parent a1a41cb commit d4d114a
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 225 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.starrocks.analysis.Expr;
import com.starrocks.catalog.AggregateType;
import com.starrocks.catalog.Column;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
Expand All @@ -14,22 +16,23 @@
import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.starrocks.catalog.Function.CompareMode.IS_IDENTICAL;

// Rewrite project -> agg -> project -> scan logic operator by RewriteContext
// Currently, only used for percentile_union mv.
// we need rewrite percentile_approx to percentile_approx_raw(percentile_union)
// TODO(kks): Remove this class if we support percentile_union_count aggregate function later
public class MVProjectAggProjectScanRewrite extends MVAggRewrite {
private MVProjectAggProjectScanRewrite() {
}

public class MVProjectAggProjectScanRewrite {
private static final MVProjectAggProjectScanRewrite instance = new MVProjectAggProjectScanRewrite();

public static MVProjectAggProjectScanRewrite getInstance() {
Expand All @@ -50,30 +53,26 @@ public void rewriteOptExpressionTree(
input.inputAt(0).inputAt(0).inputAt(0).getOp() instanceof LogicalOlapScanOperator) {
LogicalProjectOperator topProject = (LogicalProjectOperator) input.getOp();
LogicalAggregationOperator agg = (LogicalAggregationOperator) input.inputAt(0).getOp();
LogicalProjectOperator bellowProject = (LogicalProjectOperator)
input.inputAt(0).inputAt(0).getOp();
LogicalOlapScanOperator scanOperator = (LogicalOlapScanOperator)
input.inputAt(0).inputAt(0).inputAt(0).getOp();
if (factory.getRelationId(scanOperator.getOutputColumns().get(0).getId()) == relationId) {
for (MaterializedViewRule.RewriteContext context : rewriteContexts) {
rewriteOlapScanOperator(input.inputAt(0).inputAt(0), scanOperator, context);
ColumnRefOperator projectColumn = rewriteProjectOperator(bellowProject,
context.queryColumnRef,
context.mvColumnRef);
rewriteAggOperator(agg, context.aggCall,
projectColumn,
context.mvColumn);
rewriteTopProjectOperator(agg, topProject,
projectColumn, context.aggCall);
}
LogicalProjectOperator bellowProject = (LogicalProjectOperator) input.inputAt(0).inputAt(0).getOp();
LogicalOlapScanOperator scanOperator =
(LogicalOlapScanOperator) input.inputAt(0).inputAt(0).inputAt(0).getOp();

if (factory.getRelationId(scanOperator.getOutputColumns().get(0).getId()) != relationId) {
return;
}

rewriteOlapScanOperator(input.inputAt(0).inputAt(0), scanOperator, rewriteContexts);
for (MaterializedViewRule.RewriteContext context : rewriteContexts) {
ColumnRefOperator projectColumn =
rewriteProjectOperator(bellowProject, context.queryColumnRef, context.mvColumnRef);
rewriteAggOperator(agg, context.aggCall, projectColumn, context.mvColumn);
rewriteTopProjectOperator(agg, topProject, projectColumn, context.aggCall);
}
}
}

private void rewriteTopProjectOperator(LogicalAggregationOperator agg,
LogicalProjectOperator project,
ColumnRefOperator aggUsedColumn,
CallOperator queryAgg) {
private void rewriteTopProjectOperator(LogicalAggregationOperator agg, LogicalProjectOperator project,
ColumnRefOperator aggUsedColumn, CallOperator queryAgg) {
ColumnRefOperator percentileColumn = null;
for (Map.Entry<ColumnRefOperator, CallOperator> kv : agg.getAggregations().entrySet()) {
if (kv.getValue().getFnName().equals(FunctionSet.PERCENTILE_UNION)
Expand Down Expand Up @@ -101,4 +100,80 @@ private void rewriteTopProjectOperator(LogicalAggregationOperator agg,
}
}
}

// Use mv column instead of query column
protected static void rewriteOlapScanOperator(OptExpression optExpression, LogicalOlapScanOperator olapScanOperator,
List<MaterializedViewRule.RewriteContext> rewriteContexts) {
List<ColumnRefOperator> outputColumns = new ArrayList<>(olapScanOperator.getOutputColumns());
Map<ColumnRefOperator, Column> columnRefOperatorColumnMap =
new HashMap<>(olapScanOperator.getColRefToColumnMetaMap());

for (MaterializedViewRule.RewriteContext rewriteContext : rewriteContexts) {
outputColumns.remove(rewriteContext.queryColumnRef);
outputColumns.add(rewriteContext.mvColumnRef);

columnRefOperatorColumnMap.remove(rewriteContext.queryColumnRef);
columnRefOperatorColumnMap.put(rewriteContext.mvColumnRef, rewriteContext.mvColumn);
}

LogicalOlapScanOperator newScanOperator = new LogicalOlapScanOperator(
olapScanOperator.getTable(),
outputColumns,
columnRefOperatorColumnMap,
olapScanOperator.getColumnMetaToColRefMap(),
olapScanOperator.getDistributionSpec(),
olapScanOperator.getLimit(),
olapScanOperator.getPredicate(),
olapScanOperator.getSelectedIndexId(),
olapScanOperator.getSelectedPartitionId(),
olapScanOperator.getPartitionNames(),
olapScanOperator.getSelectedTabletId(),
olapScanOperator.getHintsTabletIds());

optExpression.setChild(0, OptExpression.create(newScanOperator));
}

// Use mv column instead of query column
protected ColumnRefOperator rewriteProjectOperator(LogicalProjectOperator projectOperator,
ColumnRefOperator baseColumnRef,
ColumnRefOperator mvColumnRef) {
for (Map.Entry<ColumnRefOperator, ScalarOperator> kv : projectOperator.getColumnRefMap().entrySet()) {
if (kv.getValue().getUsedColumns().contains(baseColumnRef)) {
kv.setValue(mvColumnRef);
return kv.getKey();
}
}
Preconditions.checkState(false, "shouldn't reach here");
return null;
}

// TODO(kks): refactor this method later
// query: percentile_approx(a) && mv: percentile_union(a) -> percentile_union(a)
protected void rewriteAggOperator(LogicalAggregationOperator aggOperator,
CallOperator agg,
ColumnRefOperator aggUsedColumn,
Column mvColumn) {
for (Map.Entry<ColumnRefOperator, CallOperator> kv : aggOperator.getAggregations().entrySet()) {
String functionName = kv.getValue().getFnName();
if (functionName.equals(agg.getFnName()) &&
kv.getValue().getUsedColumns().getFirstId() == aggUsedColumn.getId()) {
if (functionName.equals(FunctionSet.PERCENTILE_APPROX) &&
mvColumn.getAggregationType() == AggregateType.PERCENTILE_UNION) {
kv.setValue(getPercentileFunction(kv.getValue()));
break;
}
}
}
}

private CallOperator getPercentileFunction(CallOperator oldAgg) {
Function fn = Expr.getBuiltinFunction(FunctionSet.PERCENTILE_UNION,
new Type[] {Type.PERCENTILE}, IS_IDENTICAL);
ScalarOperator child = oldAgg.getChildren().get(0);
if (child instanceof CastOperator) {
child = child.getChild(0);
}
Preconditions.checkState(child.isColumnRef());
return new CallOperator(FunctionSet.PERCENTILE_UNION, oldAgg.getType(), Lists.newArrayList(child), fn);
}
}
Loading

0 comments on commit d4d114a

Please sign in to comment.