Skip to content

Commit

Permalink
[CALCITE-5483] ProjectAggregateMergeRule throws exception if literal …
Browse files Browse the repository at this point in the history
…is non-numeric

Close apache#3038

Co-authored-by: Bruce Irschick <brucei@bitquilltech.com>
  • Loading branch information
libenchao and Bruce Irschick committed Feb 9, 2023
1 parent b64cb13 commit 1a54261
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,21 @@ && kindCount(project.getProjects(), SqlKind.CASE) == 0) {
final RexInputRef ref0 = (RexInputRef) isNotNull.operands.get(0);
final RexCall cast = (RexCall) operands.get(1);
final RexInputRef ref1 = (RexInputRef) cast.operands.get(0);
if (ref0.getIndex() != ref1.getIndex()) {
break;
}
final int aggCallIndex = ref1.getIndex() - aggregate.getGroupCount();
if (aggCallIndex < 0) {
break;
}
final AggregateCall aggCall = aggregate.getAggCallList().get(aggCallIndex);
if (aggCall.getAggregation().getKind() != SqlKind.SUM) {
break;
}
final RexLiteral literal = (RexLiteral) operands.get(2);
if (ref0.getIndex() == ref1.getIndex()
&& Objects.equals(literal.getValueAs(BigDecimal.class), BigDecimal.ZERO)) {
final int aggCallIndex =
ref1.getIndex() - aggregate.getGroupCount();
if (aggCallIndex >= 0) {
final AggregateCall aggCall =
aggregate.getAggCallList().get(aggCallIndex);
if (aggCall.getAggregation().getKind() == SqlKind.SUM) {
int j =
findSum0(cluster.getTypeFactory(), aggCall, aggCallList);
return cluster.getRexBuilder().makeInputRef(call.type, j);
}
}
if (Objects.equals(literal.getValueAs(BigDecimal.class), BigDecimal.ZERO)) {
int j = findSum0(cluster.getTypeFactory(), aggCall, aggCallList);
return cluster.getRexBuilder().makeInputRef(call.type, j);
}
}
break;
Expand Down
42 changes: 42 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlLibrary;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.test.SqlTestFactory;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
Expand Down Expand Up @@ -5241,6 +5242,47 @@ private HepProgram getTransitiveProgram() {
.check();
}

/** Tests that ProjectAggregateMergeRule does nothing with non-numeric literals
* and does not throw an exception. */
@Test void testProjectAggregateMergeNonNumericLiteral() {
// Requires a NULLABLE column to trigger
final SqlTestFactory.CatalogReaderFactory catalogReaderFactory = (typeFactory, caseSensitive) ->
new MockCatalogReader(typeFactory, caseSensitive) {
@Override public MockCatalogReader init() {
MockSchema schema = new MockSchema("SALES");
registerSchema(schema);
final boolean nullable = true;
final RelDataType timestampType = typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.TIMESTAMP),
nullable);
String tableName = "NULLABLE";
MockTable table = MockTable
.create(this, schema, tableName, false, 100);
table.addColumn("HIREDATE", timestampType);
registerTable(table);
return this;
}
}.init();
final String sql = "select hiredate, coalesce(hiredate, {ts '1969-12-31 00:00:00'}) as c1\n"
+ "from sales.nullable\n"
+ "group by hiredate";
sql(sql)
.withCatalogReaderFactory(catalogReaderFactory)
.withRule(CoreRules.PROJECT_AGGREGATE_MERGE)
.checkUnchanged();
}

@Test void testProjectAggregateMergeNoOpForNonSum() {
final String sql = "select coalesce(m, 0)\n"
+ "from (\n"
+ " select max(deptno) as m\n"
+ " from sales.emp\n"
+ ")";
sql(sql)
.withRule(CoreRules.PROJECT_AGGREGATE_MERGE)
.checkUnchanged();
}

/**
* Test case for AggregateMergeRule, should merge 2 aggregates
* into a single aggregate.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6329,6 +6329,37 @@ LogicalProject(EXPR$0=[+(+($1, $3), $2)])
LogicalAggregate(group=[{0, 1}], MS=[MIN($2)], SS=[SUM($2)])
LogicalProject(JOB=[$2], DEPTNO=[$7], SAL=[$5])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testProjectAggregateMergeNoOpForNonSum">
<Resource name="sql">
<![CDATA[select coalesce(m, 0)
from (
select max(deptno) as m
from sales.emp
)]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EXPR$0=[CASE(IS NOT NULL($0), CAST($0):INTEGER NOT NULL, 0)])
LogicalAggregate(group=[{}], M=[MAX($0)])
LogicalProject(DEPTNO=[$7])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testProjectAggregateMergeNonNumericLiteral">
<Resource name="sql">
<![CDATA[select hiredate, coalesce(hiredate, {ts '1969-12-31 00:00:00'}) as c1
from sales.nullable
group by hiredate]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(HIREDATE=[$0], C1=[CASE(IS NOT NULL($0), CAST($0):TIMESTAMP(0) NOT NULL, 1969-12-31 00:00:00)])
LogicalAggregate(group=[{0}])
LogicalTableScan(table=[[CATALOG, SALES, NULLABLE]])
]]>
</Resource>
</TestCase>
Expand Down

0 comments on commit 1a54261

Please sign in to comment.