Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CALCITE-5483] ProjectAggregateMergeRule throws exception if literal is non-numeric #3038

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,21 @@ && kindCount(project.getProjects(), SqlKind.CASE) == 0) {
final RexInputRef ref0 = (RexInputRef) isNotNull.operands.get(0);
final RexCall cast = (RexCall) operands.get(1);
final RexInputRef ref1 = (RexInputRef) cast.operands.get(0);
if (ref0.getIndex() != ref1.getIndex()) {
break;
}
final int aggCallIndex = ref1.getIndex() - aggregate.getGroupCount();
if (aggCallIndex < 0) {
break;
}
final AggregateCall aggCall = aggregate.getAggCallList().get(aggCallIndex);
if (aggCall.getAggregation().getKind() != SqlKind.SUM) {
break;
}
final RexLiteral literal = (RexLiteral) operands.get(2);
if (ref0.getIndex() == ref1.getIndex()
&& Objects.equals(literal.getValueAs(BigDecimal.class), BigDecimal.ZERO)) {
final int aggCallIndex =
ref1.getIndex() - aggregate.getGroupCount();
if (aggCallIndex >= 0) {
final AggregateCall aggCall =
aggregate.getAggCallList().get(aggCallIndex);
if (aggCall.getAggregation().getKind() == SqlKind.SUM) {
int j =
findSum0(cluster.getTypeFactory(), aggCall, aggCallList);
return cluster.getRexBuilder().makeInputRef(call.type, j);
}
}
if (Objects.equals(literal.getValueAs(BigDecimal.class), BigDecimal.ZERO)) {
int j = findSum0(cluster.getTypeFactory(), aggCall, aggCallList);
return cluster.getRexBuilder().makeInputRef(call.type, j);
}
}
break;
Expand Down
42 changes: 42 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlLibrary;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.test.SqlTestFactory;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
Expand Down Expand Up @@ -5241,6 +5242,47 @@ private HepProgram getTransitiveProgram() {
.check();
}

/** Tests that ProjectAggregateMergeRule does nothing with non-numeric literals
* and does not throw an exception. */
@Test void testProjectAggregateMergeNonNumericLiteral() {
// Requires a NULLABLE column to trigger
final SqlTestFactory.CatalogReaderFactory catalogReaderFactory = (typeFactory, caseSensitive) ->
new MockCatalogReader(typeFactory, caseSensitive) {
@Override public MockCatalogReader init() {
MockSchema schema = new MockSchema("SALES");
registerSchema(schema);
final boolean nullable = true;
final RelDataType timestampType = typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.TIMESTAMP),
nullable);
String tableName = "NULLABLE";
MockTable table = MockTable
.create(this, schema, tableName, false, 100);
table.addColumn("HIREDATE", timestampType);
registerTable(table);
return this;
}
}.init();
final String sql = "select hiredate, coalesce(hiredate, {ts '1969-12-31 00:00:00'}) as c1\n"
+ "from sales.nullable\n"
+ "group by hiredate";
sql(sql)
.withCatalogReaderFactory(catalogReaderFactory)
.withRule(CoreRules.PROJECT_AGGREGATE_MERGE)
.checkUnchanged();
}

@Test void testProjectAggregateMergeNoOpForNonSum() {
final String sql = "select coalesce(m, 0)\n"
+ "from (\n"
+ " select max(deptno) as m\n"
+ " from sales.emp\n"
+ ")";
sql(sql)
.withRule(CoreRules.PROJECT_AGGREGATE_MERGE)
.checkUnchanged();
}

/**
* Test case for AggregateMergeRule, should merge 2 aggregates
* into a single aggregate.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6329,6 +6329,37 @@ LogicalProject(EXPR$0=[+(+($1, $3), $2)])
LogicalAggregate(group=[{0, 1}], MS=[MIN($2)], SS=[SUM($2)])
LogicalProject(JOB=[$2], DEPTNO=[$7], SAL=[$5])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testProjectAggregateMergeNoOpForNonSum">
<Resource name="sql">
<![CDATA[select coalesce(m, 0)
from (
select max(deptno) as m
from sales.emp
)]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EXPR$0=[CASE(IS NOT NULL($0), CAST($0):INTEGER NOT NULL, 0)])
LogicalAggregate(group=[{}], M=[MAX($0)])
LogicalProject(DEPTNO=[$7])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testProjectAggregateMergeNonNumericLiteral">
<Resource name="sql">
<![CDATA[select hiredate, coalesce(hiredate, {ts '1969-12-31 00:00:00'}) as c1
from sales.nullable
group by hiredate]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(HIREDATE=[$0], C1=[CASE(IS NOT NULL($0), CAST($0):TIMESTAMP(0) NOT NULL, 1969-12-31 00:00:00)])
LogicalAggregate(group=[{0}])
LogicalTableScan(table=[[CATALOG, SALES, NULLABLE]])
]]>
</Resource>
</TestCase>
Expand Down