Skip to content

Commit

Permalink
[BugFix] Fix mv rewrite bug for count agg func with rollup (StarRocks…
Browse files Browse the repository at this point in the history
…#34319)

Signed-off-by: shuming.li <ming.moriarty@gmail.com>
  • Loading branch information
LiShuMing authored Nov 10, 2023
1 parent bc35e7b commit 859226e
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.starrocks.analysis.Expr;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.MaterializedView;
Expand All @@ -40,12 +41,14 @@
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -54,6 +57,7 @@
import java.util.Set;
import java.util.stream.Collectors;

import static com.starrocks.catalog.Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF;
import static com.starrocks.sql.optimizer.OptimizerTraceUtil.logMVRewrite;
import static com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorUtil.findArithmeticFunction;

Expand Down Expand Up @@ -377,16 +381,20 @@ private OptExpression rewriteForRollup(
newAggColumnRef, newAggColumnRef.getType(), newAggColumnRef.isNullable());
queryColumnRefToScalarMap.put(entry.getKey(), rewriteAggColumnRef);
}

// generate new agg exprs(rollup functions)
final Map<ColumnRefOperator, ScalarOperator> newProjection = new HashMap<>();
Map<ColumnRefOperator, CallOperator> newAggregations = rewriteAggregates(
queryAggregation, equationRewriter, rewriteContext.getOutputMapping(),
new ColumnRefSet(rewriteContext.getQueryColumnSet()), queryColumnRefToScalarMap);
new ColumnRefSet(rewriteContext.getQueryColumnSet()), queryColumnRefToScalarMap,
newProjection, !newQueryGroupKeys.isEmpty());
if (newAggregations == null) {
logMVRewrite(mvRewriteContext, "Rewrite rollup aggregate failed: cannot rewrite aggregate functions");
return null;
}

return createNewAggregate(rewriteContext, rewrittenQueryAggOp, newAggregations, queryColumnRefToScalarMap, mvOptExpr);
return createNewAggregate(rewriteContext, rewrittenQueryAggOp, newAggregations,
queryColumnRefToScalarMap, mvOptExpr, newProjection);
}

@Override
Expand Down Expand Up @@ -461,21 +469,24 @@ protected OptExpression createUnion(OptExpression queryInput, OptExpression view
aggregateMapping.put(entry.getKey(), mapped);
}

final Map<ColumnRefOperator, ScalarOperator> newProjection = new HashMap<>();
Map<ColumnRefOperator, CallOperator> newAggregations = rewriteAggregatesForUnion(
queryAgg.getAggregations(), columnMapping, aggregateMapping);
queryAgg.getAggregations(), columnMapping, aggregateMapping, newProjection, !originalGroupKeys.isEmpty());
if (newAggregations == null) {
logMVRewrite(mvRewriteContext, "Rewrite aggregate with union failed: rewrite aggregate for union failed");
return null;
}
return createNewAggregate(rewriteContext, queryAgg, newAggregations, aggregateMapping, unionExpr);
return createNewAggregate(rewriteContext, queryAgg, newAggregations, aggregateMapping,
unionExpr, newProjection);
}

private OptExpression createNewAggregate(
RewriteContext rewriteContext,
LogicalAggregationOperator queryAgg,
Map<ColumnRefOperator, CallOperator> newAggregations,
Map<ColumnRefOperator, ScalarOperator> queryColumnRefToScalarMap,
OptExpression mvOptExpr) {
OptExpression mvOptExpr,
Map<ColumnRefOperator, ScalarOperator> newProjection) {
// newGroupKeys may have duplicate because of EquivalenceClasses
// remove duplicate here as new grouping keys
List<ColumnRefOperator> originalGroupKeys = queryAgg.getGroupingKeys();
Expand Down Expand Up @@ -532,7 +543,6 @@ private OptExpression createNewAggregate(
}

// add projection to make sure that the output columns keep the same with the origin query
Map<ColumnRefOperator, ScalarOperator> newProjection = Maps.newHashMap();
if (queryAgg.getProjection() == null) {
for (int i = 0; i < originalGroupKeys.size(); i++) {
newProjection.put(originalGroupKeys.get(i), newGroupByKeyColumnRefs.get(i));
Expand Down Expand Up @@ -588,9 +598,12 @@ private Map<ColumnRefOperator, CallOperator> rewriteAggregates(Map<ColumnRefOper
EquationRewriter equationRewriter,
Map<ColumnRefOperator, ColumnRefOperator> mapping,
ColumnRefSet queryColumnSet,
Map<ColumnRefOperator, ScalarOperator> aggregateMapping) {
Map<ColumnRefOperator, CallOperator> newAggregations = Maps.newHashMap();
Map<ColumnRefOperator, ScalarOperator> aggregateMapping,
Map<ColumnRefOperator, ScalarOperator> newProjection,
boolean hasGroupByKeys) {
final Map<ColumnRefOperator, CallOperator> newAggregations = Maps.newHashMap();
equationRewriter.setOutputMapping(mapping);

for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : aggregates.entrySet()) {
Preconditions.checkState(entry.getValue() instanceof CallOperator);
CallOperator aggCall = (CallOperator) entry.getValue();
Expand All @@ -615,6 +628,7 @@ private Map<ColumnRefOperator, CallOperator> rewriteAggregates(Map<ColumnRefOper
}
ColumnRefOperator oldColRef = (ColumnRefOperator) aggregateMapping.get(entry.getKey());
newAggregations.put(oldColRef, newAggregate);
newProjection.put(oldColRef, genRollupProject(aggCall, oldColRef, hasGroupByKeys));
}

return newAggregations;
Expand All @@ -623,7 +637,9 @@ private Map<ColumnRefOperator, CallOperator> rewriteAggregates(Map<ColumnRefOper
private Map<ColumnRefOperator, CallOperator> rewriteAggregatesForUnion(
Map<ColumnRefOperator, CallOperator> aggregates,
Map<ColumnRefOperator, ColumnRefOperator> mapping,
Map<ColumnRefOperator, ScalarOperator> aggregateMapping) {
Map<ColumnRefOperator, ScalarOperator> aggregateMapping,
Map<ColumnRefOperator, ScalarOperator> newProjection,
boolean hasGroupByKeys) {
Map<ColumnRefOperator, CallOperator> rewrittens = Maps.newHashMap();
for (Map.Entry<ColumnRefOperator, CallOperator> entry : aggregates.entrySet()) {
Preconditions.checkState(entry.getValue() != null);
Expand All @@ -641,12 +657,28 @@ private Map<ColumnRefOperator, CallOperator> rewriteAggregatesForUnion(
aggCall.toString());
return null;
}
rewrittens.put((ColumnRefOperator) aggregateMapping.get(entry.getKey()), newAggregate);
ColumnRefOperator oldColRef = (ColumnRefOperator) aggregateMapping.get(entry.getKey());
rewrittens.put(oldColRef, newAggregate);
newProjection.put(oldColRef, genRollupProject(aggCall, oldColRef, hasGroupByKeys));
}

return rewrittens;
}

private ScalarOperator genRollupProject(CallOperator aggCall, ColumnRefOperator oldColRef, boolean hasGroupByKeys) {
if (!hasGroupByKeys && aggCall.getFnName().equals(FunctionSet.COUNT)) {
// NOTE: This can only happen when query has no group-by keys.
// The behavior is different between count(NULL) and sum(NULL), count(NULL) = 0, sum(NULL) = NULL.
// Add `coalesce(count_col, 0)` to avoid return NULL instead of 0 for count rollup.
List<ScalarOperator> args = Arrays.asList(oldColRef, ConstantOperator.createBigint(0L));
Type[] argTypes = args.stream().map(a -> a.getType()).toArray(Type[]::new);
return new CallOperator(FunctionSet.COALESCE, aggCall.getType(), args,
Expr.getBuiltinFunction(FunctionSet.COALESCE, argTypes, IS_NONSTRICT_SUPERTYPE_OF));
} else {
return oldColRef;
}
}

// generate new aggregates for rollup
// eg: count(col) -> sum(col)
private CallOperator getRollupAggregate(CallOperator aggCall, ColumnRefOperator targetColumn) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2694,6 +2694,17 @@ public void testRewriteAvg3() {
testRewriteOK(mv2, "select user_id, avg(tag_id % 10) from user_tags group by user_id;");
}

@Test
public void testCountWithRollup() {
String mv = "select user_id, count(tag_id) from user_tags group by user_id, time;";
testRewriteOK(mv, "select user_id, count(tag_id) from user_tags group by user_id, time;")
.notContain("coalesce");
testRewriteOK(mv, "select user_id, count(tag_id) from user_tags group by user_id;")
.notContain("coalesce");
testRewriteOK(mv, "select count(tag_id) from user_tags;")
.contains("coalesce");
}

@Test
public void testCountDistinctToBitmapCount1() {
String mv = "select user_id, bitmap_union(to_bitmap(tag_id)) from user_tags group by user_id;";
Expand Down
125 changes: 100 additions & 25 deletions test/sql/test_materialized_view/R/test_materialized_view_rewrite
Original file line number Diff line number Diff line change
Expand Up @@ -584,42 +584,46 @@ from emps right anti join depts using (deptno);
-- result:
2
-- !result


-- name: test_single_table_mv_rewrite
create table user_tags (time date, user_id int, user_name varchar(20), tag_id int) partition by range (time) (partition p1 values less than MAXVALUE) distributed by hash(time) buckets 3 properties('replication_num' = '1');
-- result:
-- !result
insert into user_tags values('2023-04-13', 1, 'a', 1);
-- result:
-- !result
insert into user_tags values('2023-04-13', 1, 'b', 2);
-- result:
-- !result
insert into user_tags values('2023-04-13', 1, 'c', 3);
-- result:
-- !result
insert into user_tags values('2023-04-13', 1, 'd', 4);
-- result:
-- !result
insert into user_tags values('2023-04-13', 1, 'e', 5);

-- result:
-- !result
create materialized view agg_count_mv1
distributed by hash(user_id)
as
select user_id, count(1) as cnt
from user_tags
group by user_id;

-- result:
-- !result
refresh materialized view agg_count_mv1 with sync mode;
analyze table agg_count_mv1 with sync mode;

create materialized view agg_count_mv2
distributed by hash(user_id)
as
select user_id, user_name, count(1) as cnt
from user_tags
group by user_id, user_name;

-- result:
-- !result
refresh materialized view agg_count_mv2 with sync mode;
analyze table agg_count_mv2 with sync mode;

explain select user_id, count(1) as cnt
from user_tags
group by user_id;
-- result:
[REGEX]agg_count_mv1

CREATE TABLE `user_tags_2` (
`time` date NULL COMMENT "",
`user_id` int(11) NULL COMMENT "",
Expand All @@ -638,49 +642,120 @@ PROPERTIES (
"light_schema_change" = "true",
"compression" = "LZ4"
);

-- result:
-- !result
insert into user_tags_2 values('2023-04-13', 1, 'a', 1);
-- result:
-- !result
insert into user_tags_2 values('2023-04-13', 2, 'b', 2);
-- result:
-- !result
insert into user_tags_2 values('2023-04-13', 3, 'c', 3);
-- result:
-- !result
insert into user_tags_2 values('2023-04-13', 4, 'd', 4);
-- result:
-- !result
insert into user_tags_2 values('2023-04-13', 5, 'e', 5);

-- result:
-- !result
create materialized view agg_count_mv3
distributed by hash(user_id)
as
select user_id, count(1) as cnt
from user_tags_2
group by user_id;

-- result:
-- !result
refresh materialized view agg_count_mv3 with sync mode;

create materialized view agg_count_mv4
distributed by hash(user_id)
as
select user_id, user_name, count(1) as cnt
from user_tags_2
group by user_id, user_name;

-- result:
-- !result
refresh materialized view agg_count_mv4 with sync mode;


explain select user_id, count(1) as cnt
from user_tags_2
group by user_id;
-- result:
[REGEX]agg_count_mv3

create materialized view agg_count_mv5
distributed by hash(user_id)
as
select user_id, user_name, count(1) as cnt, sum(tag_id) as total
from user_tags_2
group by user_id, user_name;

-- result:
-- !result
refresh materialized view agg_count_mv5 with sync mode;

explain select user_id, user_name, count(1) as cnt
from user_tags_2
group by user_id, user_name;
-- name: test_count_rollup_with_empty_table
create table empty_tbl(time date, user_id int, user_name varchar(20), tag_id int) partition by range (time) (partition p1 values less than MAXVALUE) distributed by hash(time) buckets 3 properties('replication_num' = '1');
-- result:
-- !result
create materialized view empty_tbl_with_mv distributed by hash(user_id)
as select user_id, count(tag_id) from empty_tbl group by user_id, time;
-- result:
-- !result
select user_id, count(tag_id) from empty_tbl group by user_id, time;
-- result:
-- !result
select user_id, count(tag_id) from empty_tbl group by user_id;
-- result:
-- !result
drop table empty_tbl;
-- result:
-- !result
drop materialized view empty_tbl_with_mv;
-- result:
-- !result
CREATE TABLE orders (
dt date NOT NULL,
order_id bigint NOT NULL,
user_id int NOT NULL,
merchant_id int NOT NULL,
good_id int NOT NULL,
good_name string NOT NULL,
price int NOT NULL,
cnt int NOT NULL,
revenue int NOT NULL,
state tinyint NOT NULL
)
PRIMARY KEY (dt, order_id)
PARTITION BY RANGE(dt) (
PARTITION p20210820 VALUES [('2021-08-20'), ('2021-08-21')),
PARTITION p20210821 VALUES [('2021-08-21'), ('2021-08-22'))
)
DISTRIBUTED BY HASH(order_id) BUCKETS 4
PROPERTIES (
"replication_num" = "1",
"enable_persistent_index" = "true"
);
-- result:
-- !result

CREATE MATERIALIZED VIEW order_mv2
PARTITION BY date_trunc('MONTH', dt)
DISTRIBUTED BY HASH(order_id) BUCKETS 10
REFRESH ASYNC START('2023-07-01 10:00:00') EVERY (interval 1 day)
AS
select
dt,
order_id,
user_id,
sum(cnt) as total_cnt,
sum(revenue) as total_revenue,
count(state) as state_count
from orders group by dt, order_id, user_id;
-- result:
-- !result
select count() from orders;
-- result:
0
-- !result
drop materialized view order_mv2;
-- result:
[REGEX]agg_count_mv4
-- !result
Loading

0 comments on commit 859226e

Please sign in to comment.