Skip to content

Commit

Permalink
[CALCITE-6261] AssertionError with field pruning & duplicate agg calls
Browse files Browse the repository at this point in the history
Signed-off-by: Niels Pardon <par@zurich.ibm.com>
  • Loading branch information
nielspardon authored and mihaibudiu committed Feb 26, 2024
1 parent 6ba3130 commit 022d878
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 21 deletions.
55 changes: 34 additions & 21 deletions core/src/main/java/org/apache/calcite/tools/RelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -2442,9 +2442,22 @@ private RelBuilder aggregate_(GroupKeyImpl groupKey,
assert groupSet.contains(set);
}

PairList<ImmutableSet<String>, RelDataTypeField> inFields = frame.fields;
final ImmutableBitSet groupSet2;
final ImmutableList<ImmutableBitSet> groupSets2;
return pruneAggregateInputFieldsAndDeduplicateAggCalls(r, groupSet, groupSets, aggregateCalls,
frame.fields, registrar.extraNodes);
}

/**
* Prunes unused fields on the input of the aggregate and removes duplicate aggregation calls.
*/
private RelBuilder pruneAggregateInputFieldsAndDeduplicateAggCalls(
RelNode r,
final ImmutableBitSet groupSet,
final ImmutableList<ImmutableBitSet> groupSets,
final List<AggregateCall> aggregateCalls,
PairList<ImmutableSet<String>, RelDataTypeField> inFields,
final List<RexNode> extraNodes) {
final ImmutableBitSet groupSetAfterPruning;
final ImmutableList<ImmutableBitSet> groupSetsAfterPruning;
if (config.pruneInputOfAggregate()
&& r instanceof Project) {
final Set<Integer> fieldsUsed =
Expand All @@ -2453,22 +2466,22 @@ private RelBuilder aggregate_(GroupKeyImpl groupKey,
// pretend that one field is used.
if (fieldsUsed.isEmpty()) {
r = ((Project) r).getInput();
groupSet2 = groupSet;
groupSets2 = groupSets;
groupSetAfterPruning = groupSet;
groupSetsAfterPruning = groupSets;
} else if (fieldsUsed.size() < r.getRowType().getFieldCount()) {
// Some fields are computed but not used. Prune them.
final Map<Integer, Integer> map = new HashMap<>();
final Map<Integer, Integer> sourceFieldToTargetFieldMap = new HashMap<>();
for (int source : fieldsUsed) {
map.put(source, map.size());
sourceFieldToTargetFieldMap.put(source, sourceFieldToTargetFieldMap.size());
}

groupSet2 = groupSet.permute(map);
groupSets2 =
groupSetAfterPruning = groupSet.permute(sourceFieldToTargetFieldMap);
groupSetsAfterPruning =
ImmutableBitSet.ORDERING.immutableSortedCopy(
ImmutableBitSet.permute(groupSets, map));
ImmutableBitSet.permute(groupSets, sourceFieldToTargetFieldMap));

final Mappings.TargetMapping targetMapping =
Mappings.target(map, r.getRowType().getFieldCount(),
Mappings.target(sourceFieldToTargetFieldMap, r.getRowType().getFieldCount(),
fieldsUsed.size());
final List<AggregateCall> oldAggregateCalls =
new ArrayList<>(aggregateCalls);
Expand All @@ -2493,24 +2506,24 @@ private RelBuilder aggregate_(GroupKeyImpl groupKey,
project.copy(cluster.traitSet(), project.getInput(), newProjects,
builder.build());
} else {
groupSet2 = groupSet;
groupSets2 = groupSets;
groupSetAfterPruning = groupSet;
groupSetsAfterPruning = groupSets;
}
} else {
groupSet2 = groupSet;
groupSets2 = groupSets;
groupSetAfterPruning = groupSet;
groupSetsAfterPruning = groupSets;
}

if (!config.dedupAggregateCalls() || Util.isDistinct(aggregateCalls)) {
return aggregate_(groupSet2, groupSets2, r, aggregateCalls,
registrar.extraNodes, inFields);
return aggregate_(groupSetAfterPruning, groupSetsAfterPruning, r, aggregateCalls,
extraNodes, inFields);
}

// There are duplicate aggregate calls. Rebuild the list to eliminate
// duplicates, then add a Project.
final Set<AggregateCall> callSet = new HashSet<>();
final PairList<Integer, @Nullable String> projects = PairList.of();
Util.range(groupSet.cardinality())
Util.range(groupSetAfterPruning.cardinality())
.forEach(i -> projects.add(i, null));
final List<AggregateCall> distinctAggregateCalls = new ArrayList<>();
for (AggregateCall aggregateCall : aggregateCalls) {
Expand All @@ -2522,10 +2535,10 @@ private RelBuilder aggregate_(GroupKeyImpl groupKey,
i = distinctAggregateCalls.indexOf(aggregateCall);
assert i >= 0;
}
projects.add(groupSet.cardinality() + i, aggregateCall.name);
projects.add(groupSetAfterPruning.cardinality() + i, aggregateCall.name);
}
aggregate_(groupSet, groupSets, r, distinctAggregateCalls,
registrar.extraNodes, inFields);
aggregate_(groupSetAfterPruning, groupSetsAfterPruning, r, distinctAggregateCalls,
extraNodes, inFields);
return project(projects.transform((i, name) -> aliasMaybe(field(i), name)));
}

Expand Down
61 changes: 61 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,67 @@ private RelNode buildRelWithDuplicateAggregates(
assertThat(root, hasTree(expected));
}

/**
 * Test case for CALCITE-6261: an aggregate that contains duplicate aggregate
 * calls, built on top of a forced project of all columns whose unused fields
 * are pruned.
 *
 * <p>The expected plan computes the duplicated {@code AVG(SAL)} once and
 * re-projects it under both names, while the input project keeps only
 * {@code MGR} and {@code SAL}.
 */
@Test void testAggregateDuplicateAggCallsWithForceProjectAndFieldPruning() {
  final RelBuilder b = createBuilder();
  final RelNode root =
      b.scan("EMP")
          // Forced project (last argument is true) listing every EMP column.
          .project(
              ImmutableList.of(
                  b.field("EMPNO"),
                  b.field("ENAME"),
                  b.field("JOB"),
                  b.field("MGR"),
                  b.field("HIREDATE"),
                  b.field("SAL"),
                  b.field("COMM"),
                  b.field("DEPTNO")),
              ImmutableList.of(),
              true)
          .aggregate(
              b.groupKey(b.field("MGR")),
              // SALARY_AVG and SALARY_MEAN are the same avg(SAL) call under
              // two different names, i.e. a duplicate aggregate call.
              b.avg(false, "SALARY_AVG", b.field("SAL")),
              b.sum(false, "SALARY_SUM", b.field("SAL")),
              b.avg(false, "SALARY_MEAN", b.field("SAL")))
          .build();

  // Top project restores SALARY_MEAN by referencing the single AVG ($1).
  final String expectedPlan = ""
      + "LogicalProject(MGR=[$0], SALARY_AVG=[$1], SALARY_SUM=[$2], SALARY_MEAN=[$1])\n"
      + "  LogicalAggregate(group=[{0}], SALARY_AVG=[AVG($1)], SALARY_SUM=[SUM($1)])\n"
      + "    LogicalProject(MGR=[$3], SAL=[$5])\n"
      + "      LogicalTableScan(table=[[scott, EMP]])\n";
  assertThat(root, hasTree(expectedPlan));
}

/**
 * Recreates the reproducer from CALCITE-5888 using the existing scott
 * tables: an aggregate over a join, grouped by a null literal cast to
 * {@code INTEGER}, with each of {@code MIN(SAL)} and {@code MAX(SAL)}
 * requested twice.
 *
 * <p>The expected plan keeps one {@code MIN} and one {@code MAX}, groups on
 * the literal column of the pruned input project, and re-projects the
 * duplicates on top.
 */
@Test void testAggregateDuplicateAggCallsAndFieldPruningWithJoinAndLiteralGroupKey() {
  final RelBuilder b = createBuilder();
  final RelNode root =
      // Inner join of EMP and DEPT on DEPTNO feeds the aggregate.
      b.scan("EMP")
          .scan("DEPT")
          .join(JoinRelType.INNER, "DEPTNO")
          .aggregate(
              // Group key is a null literal cast to INTEGER.
              b.groupKey(b.cast(b.literal(null), SqlTypeName.INTEGER)),
              // min/max over SAL, each listed twice to create duplicates.
              b.min(b.field("SAL")),
              b.max(b.field("SAL")),
              b.min(b.field("SAL")),
              b.max(b.field("SAL")))
          .build();

  // After pruning, SAL is input field 0 and the literal group key is field 1.
  final String expectedPlan = ""
      + "LogicalProject($f11=[$0], $f1=[$1], $f2=[$2], $f10=[$1], $f20=[$2])\n"
      + "  LogicalAggregate(group=[{1}], agg#0=[MIN($0)], agg#1=[MAX($0)])\n"
      + "    LogicalProject(SAL=[$5], $f11=[null:INTEGER])\n"
      + "      LogicalJoin(condition=[=($7, $8)], joinType=[inner])\n"
      + "        LogicalTableScan(table=[[scott, EMP]])\n"
      + "        LogicalTableScan(table=[[scott, DEPT]])\n";
  assertThat(root, hasTree(expectedPlan));
}

@Test void testAggregateFilter() {
// Equivalent SQL:
// SELECT deptno, COUNT(*) FILTER (WHERE empno > 100) AS c
Expand Down

0 comments on commit 022d878

Please sign in to comment.