Skip to content

Commit

Permalink
[CALCITE-6749] RelMdUtil#setAggChildKeys may return an incorrect result
Browse files Browse the repository at this point in the history
  • Loading branch information
rubenada committed Jan 8, 2025
1 parent c7724b4 commit c686e2e
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -568,15 +568,15 @@ public static void setAggChildKeys(
ImmutableBitSet groupKey,
Aggregate aggRel,
ImmutableBitSet.Builder childKey) {
List<AggregateCall> aggCalls = aggRel.getAggCallList();
final List<AggregateCall> aggCallList = aggRel.getAggCallList();
final List<Integer> groupList = aggRel.getGroupSet().asList();
for (int bit : groupKey) {
if (bit < aggRel.getGroupCount()) {
// group by column
childKey.set(bit);
childKey.set(groupList.get(bit));
} else {
// aggregate column -- set a bit for each argument being
// aggregated
AggregateCall agg = aggCalls.get(bit - aggRel.getGroupCount());
// aggregate column -- set a bit for each argument being aggregated
final AggregateCall agg = aggCallList.get(bit - aggRel.getGroupCount());
for (Integer arg : agg.getArgList()) {
childKey.set(arg);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@
*/
package org.apache.calcite.rel.metadata;

import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.test.RelMetadataFixture;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.util.ImmutableBitSet;

import org.junit.jupiter.api.Test;

Expand All @@ -30,6 +36,7 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

/**
Expand Down Expand Up @@ -110,4 +117,39 @@ final RelMetadataFixture sql(String sql) {
});
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6749">[CALCITE-6749]
* RelMdUtil#setAggChildKeys may return an incorrect result</a>. */
@Test void testSetAggChildKeys() {
Frameworks.withPlanner((cluster, relOptSchema, rootSchema) -> {
RelNode rel = sql("select d.deptno, count(distinct e.job)\n"
+ "from sales.emp e\n"
+ "right outer join sales.dept d on e.deptno = d.deptno\n"
+ "group by d.deptno")
.withRelTransform(relNode -> {
final HepProgramBuilder builder = HepProgram.builder();
builder.addRuleInstance(CoreRules.AGGREGATE_PROJECT_MERGE);
final HepPlanner prePlanner = new HepPlanner(builder.build());
prePlanner.setRoot(relNode);
return prePlanner.findBestExp();
}).toRel();
final Aggregate agg = (Aggregate) rel;
// We should get an Aggregate(group=[{9}], EXPR$1=[COUNT(DISTINCT $2)])
assertEquals(1, agg.getGroupCount());
assertEquals(9, agg.getGroupSet().asList().get(0));
assertEquals(1, agg.getAggCallList().size());
assertEquals(1, agg.getAggCallList().get(0).getArgList().size());
assertEquals(2, agg.getAggCallList().get(0).getArgList().get(0));
// The childKey corresponding to 0 (group key) must be 9
final ImmutableBitSet.Builder builder1 = ImmutableBitSet.builder();
RelMdUtil.setAggChildKeys(ImmutableBitSet.of(0), agg, builder1);
assertEquals(ImmutableBitSet.of(9), builder1.build());
// The childKey corresponding to 1 (count aggCall) must be 2
final ImmutableBitSet.Builder builder2 = ImmutableBitSet.builder();
RelMdUtil.setAggChildKeys(ImmutableBitSet.of(1), agg, builder2);
assertEquals(ImmutableBitSet.of(2), builder2.build());
return null;
});
}

}
6 changes: 3 additions & 3 deletions core/src/test/resources/sql/agg.iq
Original file line number Diff line number Diff line change
Expand Up @@ -2862,9 +2862,9 @@ select MGR, count(distinct DEPTNO, JOB), MIN(SAL), MAX(SAL) from "scott".emp gro

!ok

EnumerableAggregate(group=[{0}], EXPR$1=[COUNT($1, $2) FILTER $5], EXPR$2=[MIN($3) FILTER $6], EXPR$3=[MIN($4) FILTER $6])
EnumerableCalc(expr#0..5=[{inputs}], expr#6=[0], expr#7=[=($t5, $t6)], expr#8=[3], expr#9=[=($t5, $t8)], MGR=[$t1], DEPTNO=[$t2], JOB=[$t0], EXPR$2=[$t3], EXPR$3=[$t4], $g_0=[$t7], $g_3=[$t9])
EnumerableAggregate(group=[{2, 3, 7}], groups=[[{2, 3, 7}, {3}]], EXPR$2=[MIN($5)], EXPR$3=[MAX($5)], $g=[GROUPING($3, $7, $2)])
EnumerableAggregate(group=[{1}], EXPR$1=[COUNT($2, $0) FILTER $5], EXPR$2=[MIN($3) FILTER $6], EXPR$3=[MIN($4) FILTER $6])
EnumerableCalc(expr#0..5=[{inputs}], expr#6=[0], expr#7=[=($t5, $t6)], expr#8=[5], expr#9=[=($t5, $t8)], proj#0..4=[{exprs}], $g_0=[$t7], $g_5=[$t9])
EnumerableAggregate(group=[{2, 3, 7}], groups=[[{2, 3, 7}, {3}]], EXPR$2=[MIN($5)], EXPR$3=[MAX($5)], $g=[GROUPING($2, $3, $7)])
EnumerableTableScan(table=[[scott, EMP]])
!plan

Expand Down
55 changes: 29 additions & 26 deletions core/src/test/resources/sql/sub-query.iq
Original file line number Diff line number Diff line change
Expand Up @@ -427,35 +427,38 @@ where e.job not in (

!ok
EnumerableCalc(expr#0..9=[{inputs}], expr#10=[0], expr#11=[=($t5, $t10)], expr#12=[IS NULL($t1)], expr#13=[IS NOT NULL($t9)], expr#14=[<($t6, $t5)], expr#15=[OR($t12, $t13, $t14)], expr#16=[IS NOT TRUE($t15)], expr#17=[OR($t11, $t16)], EMPNO=[$t0], $condition=[$t17])
EnumerableHashJoin(condition=[AND(=($1, $7), =($2, $8))], joinType=[left])
EnumerableHashJoin(condition=[=($2, $4)], joinType=[left])
EnumerableCalc(expr#0..3=[{inputs}], EMPNO=[$t1], JOB=[$t2], DEPTNO=[$t3], DEPTNO0=[$t0])
EnumerableHashJoin(condition=[=($0, $3)], joinType=[inner])
EnumerableMergeJoin(condition=[AND(=($1, $7), =($2, $8))], joinType=[left])
EnumerableSort(sort0=[$1], sort1=[$2], dir0=[ASC], dir1=[ASC])
EnumerableMergeJoin(condition=[=($2, $4)], joinType=[left])
EnumerableMergeJoin(condition=[=($2, $3)], joinType=[inner])
EnumerableSort(sort0=[$2], dir0=[ASC])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
EnumerableTableScan(table=[[scott, DEPT]])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableAggregate(group=[{3}], c=[COUNT()], ck=[COUNT($1)])
EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableAggregate(group=[{1}])
EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
EnumerableSort(sort0=[$0], dir0=[ASC])
EnumerableAggregate(group=[{3}], c=[COUNT()], ck=[COUNT($1)])
EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
EnumerableTableScan(table=[[scott, DEPT]])
EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t0)], proj#0..2=[{exprs}], $condition=[$t3])
EnumerableAggregate(group=[{1, 3}], i=[LITERAL_AGG(true)])
EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableAggregate(group=[{1}])
EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
EnumerableTableScan(table=[[scott, DEPT]])
EnumerableAggregate(group=[{1}])
EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
EnumerableTableScan(table=[[scott, DEPT]])
EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[ASC])
EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t0)], proj#0..2=[{exprs}], $condition=[$t3])
EnumerableAggregate(group=[{1, 3}], i=[LITERAL_AGG(true)])
EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableAggregate(group=[{1}])
EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
EnumerableTableScan(table=[[scott, DEPT]])
!plan

# Condition that returns a NULL key.
Expand Down

0 comments on commit c686e2e

Please sign in to comment.