Skip to content

Commit

Permalink
[Bugfix] keep the aggregate expr order when rewriting the aggregate operator (
Browse files Browse the repository at this point in the history
#7657)

We should keep the aggregate expr order when rewriting, because we will use
singleDistinctFunctionPos to decide whether to call update or merge.

(cherry picked from commit 3fadd41)
  • Loading branch information
stdpain authored and mergify[bot] committed Jun 23, 2022
1 parent d2306ea commit 64cf75b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,8 @@ private PhysicalHashAggregateOperator rewriteAggOperator(PhysicalHashAggregateOp
DecodeContext context) {
Map<Integer, Integer> newStringToDicts = Maps.newHashMap();

Map<ColumnRefOperator, CallOperator> newAggMap = Maps.newHashMap(aggOperator.getAggregations());
final List<Map.Entry<ColumnRefOperator, CallOperator>> newAggMapEntry = Lists.newArrayList();

for (Map.Entry<ColumnRefOperator, CallOperator> kv : aggOperator.getAggregations().entrySet()) {
boolean canApplyDictDecodeOpt = (kv.getValue().getUsedColumns().cardinality() > 0) &&
(PhysicalHashAggregateOperator.couldApplyLowCardAggregateFunction.contains(
Expand All @@ -597,7 +598,7 @@ private PhysicalHashAggregateOperator rewriteAggOperator(PhysicalHashAggregateOp
Collections.singletonList(dictColumn), newFunction,
oldCall.isDistinct());
ColumnRefOperator outputColumn = kv.getKey();
newAggMap.put(outputColumn, newCall);
newAggMapEntry.add(Maps.immutableEntry(outputColumn, newCall));
} else if (context.stringColumnIdToDictColumnIds.containsKey(columnId)) {
Integer dictColumnId = context.stringColumnIdToDictColumnIds.get(columnId);
ColumnRefOperator dictColumn = context.columnRefFactory.getColumnRef(dictColumnId);
Expand Down Expand Up @@ -628,7 +629,6 @@ private PhysicalHashAggregateOperator rewriteAggOperator(PhysicalHashAggregateOp
ColumnRefOperator outputStringColumn = kv.getKey();
final ColumnRefOperator newDictColumn = context.columnRefFactory.create(
dictColumn.getName(), ID_TYPE, dictColumn.isNullable());
newAggMap.remove(outputStringColumn);
newStringToDicts.put(outputStringColumn.getId(), newDictColumn.getId());

for (Pair<Integer, ColumnDict> globalDict : context.globalDicts) {
Expand All @@ -647,10 +647,15 @@ private PhysicalHashAggregateOperator rewriteAggOperator(PhysicalHashAggregateOp
newArguments, newFunction,
oldCall.isDistinct());

newAggMap.put(outputColumn, newCall);
newAggMapEntry.add(Maps.immutableEntry(outputColumn, newCall));
} else {
newAggMapEntry.add(kv);
}
} else {
newAggMapEntry.add(kv);
}
}
Map<ColumnRefOperator, CallOperator> newAggMap = ImmutableMap.copyOf(newAggMapEntry);

List<ColumnRefOperator> newGroupBys = Lists.newArrayList();
for (ColumnRefOperator groupBy : aggOperator.getGroupBys()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,26 @@ public void testDecodeNodeRewrite10() throws Exception {

@Test
public void testDecodeNodeRewriteMultiCountDistinct() throws Exception {
String sql;
String plan;
connectContext.getSessionVariable().setNewPlanerAggStage(2);
String sql = "select count(distinct a),count(distinct b) from (" +
sql = "select count(distinct a),count(distinct b) from (" +
"select lower(upper(S_ADDRESS)) as a, upper(S_ADDRESS) as b, " +
"count(*) from supplier group by a,b) as t ";
String plan = getFragmentPlan(sql);
plan = getFragmentPlan(sql);
Assert.assertFalse(plan.contains("Decode"));
Assert.assertTrue(plan.contains("7:AGGREGATE (merge finalize)\n" +
" | output: multi_distinct_count(13: count), multi_distinct_count(12: count)"));
" | output: multi_distinct_count(12: count), multi_distinct_count(13: count)"));

sql = "select count(distinct S_ADDRESS), count(distinct S_COMMENT) from supplier;";
plan = getFragmentPlan(sql);
Assert.assertTrue(plan.contains(" multi_distinct_count(11: S_ADDRESS), " +
"multi_distinct_count(12: S_COMMENT)"));
connectContext.getSessionVariable().setNewPlanerAggStage(3);
sql = "select max(S_ADDRESS), count(distinct S_ADDRESS) from supplier group by S_ADDRESS;";
plan = getFragmentPlan(sql);
Assert.assertTrue(plan.contains(" 4:AGGREGATE (update finalize)\n" +
" | output: max(13: S_ADDRESS), count(11: S_ADDRESS)"));
connectContext.getSessionVariable().setNewPlanerAggStage(0);
}

Expand Down Expand Up @@ -872,7 +879,7 @@ public void testMultiMaxMin() throws Exception {
sql = "select min(distinct S_ADDRESS), max(S_ADDRESS) from supplier_nullable";
plan = getFragmentPlan(sql);
Assert.assertTrue(plan.contains(" 1:AGGREGATE (update serialize)\n" +
" | output: max(11: S_ADDRESS), min(11: S_ADDRESS)"));
" | output: min(11: S_ADDRESS), max(11: S_ADDRESS)"));
Assert.assertTrue(plan.contains(" 3:AGGREGATE (merge finalize)\n" +
" | output: min(12: S_ADDRESS), max(13: S_ADDRESS)"));
Assert.assertTrue(plan.contains(" 4:Decode\n" +
Expand Down Expand Up @@ -914,7 +921,6 @@ public void testAssignWrongNullableProperty() throws Exception {
public void testHavingAggFunctionOnConstant() throws Exception {
String sql = "select S_ADDRESS from supplier GROUP BY S_ADDRESS HAVING (cast(count(null) as string)) IN (\"\")";
String plan = getCostExplain(sql);
System.out.println("plan = " + plan);
Assert.assertTrue(plan.contains(" 1:AGGREGATE (update finalize)\n" +
" | aggregate: count[(NULL); args: BOOLEAN; result: BIGINT; args nullable: true; result nullable: false]\n" +
" | group by: [10: S_ADDRESS, INT, false]\n" +
Expand All @@ -929,9 +935,10 @@ public void testHavingAggFunctionOnConstant() throws Exception {
public void testDecodeWithLimit() throws Exception {
String sql = "select count(*), S_ADDRESS from supplier group by S_ADDRESS limit 10";
String plan = getFragmentPlan(sql);
assertContains(plan," 3:Decode\n" +
assertContains(plan, " 3:Decode\n" +
" | <dict id 10> : <string id 3>\n" +
" | limit: 10");;
" | limit: 10");
;
}

@Test
Expand Down

0 comments on commit 64cf75b

Please sign in to comment.