Skip to content

Commit

Permalink
[BugFix] Support more operator types for text based mv rewrite in sub…
Browse files Browse the repository at this point in the history
…query (StarRocks#44674)

Signed-off-by: shuming.li <ming.moriarty@gmail.com>
  • Loading branch information
LiShuMing authored Apr 25, 2024
1 parent 09a4028 commit c46b7ce
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.starrocks.analysis.ParseNode;
Expand All @@ -45,6 +46,7 @@
import com.starrocks.sql.optimizer.operator.Operator;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalTopNOperator;
import com.starrocks.sql.optimizer.operator.pattern.Pattern;
Expand All @@ -71,6 +73,13 @@
public class TextMatchBasedRewriteRule extends Rule {
private static final Logger LOG = LogManager.getLogger(TextMatchBasedRewriteRule.class);

// Supported rewrite operator types in the sub-query to match with the specified operator types
public static final Set<OperatorType> SUPPORTED_REWRITE_OPERATOR_TYPES = ImmutableSet.of(
OperatorType.LOGICAL_PROJECT,
OperatorType.LOGICAL_UNION,
OperatorType.LOGICAL_LIMIT,
OperatorType.LOGICAL_FILTER
);
private final ConnectContext connectContext;
private final StatementBase stmt;
private final Map<Operator, ParseNode> optToAstMap;
Expand Down Expand Up @@ -379,34 +388,21 @@ private List<OptExpression> visitChildren(OptExpression optExpression, ConnectCo
return children;
}

@Override
public OptExpression visit(OptExpression optExpression, ConnectContext connectContext) {
List<OptExpression> children = visitChildren(optExpression, connectContext);
return OptExpression.create(optExpression.getOp(), children);
}

@Override
public OptExpression visitLogicalProject(OptExpression optExpression, ConnectContext connectContext) {
if (subQueryTextMatchCount++ > mvSubQueryTextMatchMaxCount) {
return optExpression;
}

OptExpression rewritten = doRewrite(optExpression);
if (rewritten != null) {
return rewritten;
}
List<OptExpression> children = visitChildren(optExpression, connectContext);
return OptExpression.create(optExpression.getOp(), children);
private boolean isReachLimit() {
return subQueryTextMatchCount++ > mvSubQueryTextMatchMaxCount;
}

@Override
public OptExpression visitLogicalUnion(OptExpression optExpression, ConnectContext connectContext) {
if (subQueryTextMatchCount++ > mvSubQueryTextMatchMaxCount) {
return optExpression;
}
OptExpression rewritten = doRewrite(optExpression);
if (rewritten != null) {
return rewritten;
public OptExpression visit(OptExpression optExpression, ConnectContext connectContext) {
LogicalOperator op = (LogicalOperator) optExpression.getOp();
if (SUPPORTED_REWRITE_OPERATOR_TYPES.contains(op.getOpType())) {
if (isReachLimit()) {
return optExpression;
}
OptExpression rewritten = doRewrite(optExpression);
if (rewritten != null) {
return rewritten;
}
}
List<OptExpression> children = visitChildren(optExpression, connectContext);
return OptExpression.create(optExpression.getOp(), children);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
import com.starrocks.sql.optimizer.operator.stream.LogicalBinlogScanOperator;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
import com.starrocks.sql.optimizer.rewrite.scalar.ReduceCastRule;
import com.starrocks.sql.optimizer.rule.transformation.materialization.rule.TextMatchBasedRewriteRule;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down Expand Up @@ -265,7 +266,9 @@ public LogicalPlan visitQueryStatement(QueryStatement node, ExpressionMapping co

@Override
public LogicalPlan visitSelect(SelectRelation node, ExpressionMapping context) {
return new QueryTransformer(columnRefFactory, session, cteContext, inlineView, optToAstMap).plan(node, outer);
QueryTransformer queryTransformer = new QueryTransformer(columnRefFactory, session, cteContext, inlineView, optToAstMap);
LogicalPlan logicalPlan = queryTransformer.plan(node, outer);
return logicalPlan;
}

@Override
Expand Down Expand Up @@ -676,10 +679,9 @@ public LogicalPlan visitSubquery(SubqueryRelation node, ExpressionMapping contex

builder = addOrderByLimit(builder, node);

// store opt expression to ast map if sub-query is project or union
// store opt expression to ast map if sub-query's type is supported.
OperatorType operatorType = subQueryOptExpression.getOp().getOpType();
if (optToAstMap != null && (operatorType == OperatorType.LOGICAL_PROJECT ||
operatorType == OperatorType.LOGICAL_UNION)) {
if (optToAstMap != null && TextMatchBasedRewriteRule.SUPPORTED_REWRITE_OPERATOR_TYPES.contains(operatorType)) {
optToAstMap.put(subQueryOptExpression.getOp(), node.getQueryStatement());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,4 +298,50 @@ public void testTextMatchRewriteWithUnionAll1() {
testRewriteOK(mv, query);
}
}

@Test
public void testTextMatchRewriteWithSubQuery1() {
String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id, time";
String query = String.format("select user_id, count(time) from (%s) as t group by user_id limit 3;", mv);
testRewriteOK(mv, query);
}

@Test
public void testTextMatchRewriteWithSubQuery2() {
String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id, time limit 3";
String query = String.format("select user_id, count(time) from (%s) as t group by user_id limit 3;", mv);
testRewriteOK(mv, query);
}

@Test
public void testTextMatchRewriteWithSubQuery3() {
String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id, time order by " +
"user_id, time";
String query = String.format("select user_id, count(time) from (%s) as t group by user_id limit 3;", mv);
// TODO: support order by elimiation
testRewriteFail(mv, query);
}

@Test
public void testTextMatchRewriteWithSubQuery4() {
String mv = "select * from (select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id," +
" time order by user_id, time) s where user_id != 'xxxx'";
String query = String.format("select user_id, count(time) from (%s) as t group by user_id limit 3;", mv);
testRewriteOK(mv, query);
}

@Test
public void testTextMatchRewriteWithExtraOrder1() {
String mv = "select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags group by user_id, time";
String query = String.format("select user_id from (%s) t order by user_id, time;", mv);
testRewriteOK(mv, query);
}
@Test
public void testTextMatchRewriteWithExtraOrder2() {
String mv = "select user_id, count(1) from (select user_id, time, bitmap_union(to_bitmap(tag_id)) as a from user_tags " +
"group by user_id, time) t group by user_id";
String query = String.format("%s order by user_id;", mv);
// TODO: support text based view for more patterns, now only rewrite the same query and subquery
testRewriteFail(mv, query);
}
}

0 comments on commit c46b7ce

Please sign in to comment.