diff --git a/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcToEnumerableConverter.java b/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcToEnumerableConverter.java index 4d2961c572f..313a896ec1f 100644 --- a/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcToEnumerableConverter.java +++ b/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcToEnumerableConverter.java @@ -348,7 +348,7 @@ private SqlString generateSql(SqlDialect dialect) { new JdbcImplementor(dialect, (JavaTypeFactory) getCluster().getTypeFactory()); final JdbcImplementor.Result result = - jdbcImplementor.visitInput(this, 0); + jdbcImplementor.visitRoot(this.getInput()); return result.asStatement().toSqlString(dialect); } } diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java index 1773a280684..48575ba831c 100644 --- a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java +++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java @@ -19,6 +19,8 @@ import org.apache.calcite.linq4j.Ord; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; @@ -29,6 +31,7 @@ import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.Window; +import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; @@ -161,8 +164,20 @@ protected SqlImplementor(SqlDialect dialect) { /** Visits a relational expression that has no parent. */ public final Result visitRoot(RelNode r) { + RelNode best; + if (!this.dialect.supportsGroupByLiteral()) { + HepProgramBuilder hepProgramBuilder = new HepProgramBuilder(); + hepProgramBuilder.addRuleInstance( + AggregateProjectConstantToDummyJoinRule.Config.DEFAULT.toRule()); + HepPlanner hepPlanner = new HepPlanner(hepProgramBuilder.build()); + + hepPlanner.setRoot(r); + best = hepPlanner.findBestExp(); + } else { + best = r; + } try { - return visitInput(holder(r), 0); + return visitInput(holder(best), 0); } catch (Error | RuntimeException e) { throw Util.throwAsRuntime("Error while converting RelNode to SqlNode:\n" + RelOptUtil.toString(r), e); diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectConstantToDummyJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectConstantToDummyJoinRule.java new file mode 100644 index 00000000000..abfb7f8a84d --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectConstantToDummyJoinRule.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.tools.RelBuilder; + +import com.google.common.collect.ImmutableList; + +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.List; + +/** + * Planner rule that recognizes a {@link org.apache.calcite.rel.core.Aggregate} + * on top of a {@link org.apache.calcite.rel.core.Project} where the aggregate's group set + * contains literals (true, false, DATE, chars, etc), and removes the literals from the + * group keys by joining with a dummy table of literals. + * + *
{@code
+ * select avg(sal)
+ * from emp
+ * group by true, DATE '2022-01-01';
+ * }
+ * becomes + *
{@code
+ * select avg(sal)
+ * from emp, (select true x, DATE '2022-01-01' d) dummy
+ * group by dummy.x, dummy.d;
+ * }
+ */ +@Value.Enclosing +public final class AggregateProjectConstantToDummyJoinRule + extends RelRule { + + /** Creates an AggregateProjectConstantToDummyJoinRule. */ + private AggregateProjectConstantToDummyJoinRule(Config config) { + super(config); + } + + @Override public boolean matches(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + final Project project = call.rel(1); + + for (int groupKey: aggregate.getGroupSet().asList()) { + if (groupKey >= aggregate.getRowType().getFieldCount()) { + continue; + } + RexNode groupKeyProject = project.getProjects().get(groupKey); + if (groupKeyProject instanceof RexLiteral) { + return true; + } + } + + return false; + } + + @Override public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + final Project project = call.rel(1); + + RelBuilder builder = call.builder(); + RexBuilder rexBuilder = builder.getRexBuilder(); + + builder.push(project.getInput()); + int offset = project.getInput().getRowType().getFieldCount(); + + RelDataTypeFactory.Builder valuesType = rexBuilder.getTypeFactory().builder(); + List literals = new ArrayList<>(); + List projects = project.getProjects(); + for (int i = 0; i < projects.size(); i++) { + RexNode node = projects.get(i); + if (node instanceof RexLiteral) { + literals.add((RexLiteral) node); + valuesType.add(project.getRowType().getFieldList().get(i)); + } + } + builder.values(ImmutableList.of(literals), valuesType.build()); + + builder.join(JoinRelType.INNER, rexBuilder.makeLiteral(true)); + + List newProjects = new ArrayList<>(); + int literalCounter = 0; + for (RexNode exp : project.getProjects()) { + if (exp instanceof RexLiteral) { + newProjects.add(builder.field(offset + literalCounter++)); + } else { + newProjects.add(exp); + } + } + + builder.project(newProjects); + builder.aggregate( + builder.groupKey( + aggregate.getGroupSet(), aggregate.getGroupSets()), aggregate.getAggCallList() + ); + + call.transformTo(builder.build()); + } + + /** Rule configuration. */ + @Value.Immutable + public interface Config extends RelRule.Config { + Config DEFAULT = ImmutableAggregateProjectConstantToDummyJoinRule.Config.of() + .withOperandFor(Aggregate.class, Project.class); + + @Override default AggregateProjectConstantToDummyJoinRule toRule() { + return new AggregateProjectConstantToDummyJoinRule(this); + } + + default Config withOperandFor(Class aggregateClass, + Class projectClass) { + return withOperandSupplier(b0 -> + b0.operand(aggregateClass).oneInput(b1 -> + b1.operand(projectClass).anyInputs())) + .as(Config.class); + } + } +} diff --git a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java index a6a735764fa..26a998baeae 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java @@ -711,6 +711,26 @@ public boolean supportsCharSet() { return true; } + /** + * Returns whether the dialect supports GROUP BY literals. + * + *

For instance, in {@link DatabaseProduct#REDSHIFT}, the following queries are illegal.

+ *
{@code
+   * select avg(salary)
+   * from emp
+   * group by true
+   * }
+ * + *
{@code
+   * select avg(salary)
+   * from emp
+   * group by 'a', DATE '2022-01-01'
+   * }
+ */ + public boolean supportsGroupByLiteral() { + return true; + } + public boolean supportsAggregateFunction(SqlKind kind) { switch (kind) { case COUNT: diff --git a/core/src/main/java/org/apache/calcite/sql/dialect/InformixSqlDialect.java b/core/src/main/java/org/apache/calcite/sql/dialect/InformixSqlDialect.java index 715ba826107..eb9acbdb59f 100644 --- a/core/src/main/java/org/apache/calcite/sql/dialect/InformixSqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/dialect/InformixSqlDialect.java @@ -31,4 +31,12 @@ public class InformixSqlDialect extends SqlDialect { public InformixSqlDialect(Context context) { super(context); } + + @Override public boolean supportsGroupByLiteral() { + return false; + } + + @Override public boolean supportsAliasedValues() { + return false; + } } diff --git a/core/src/main/java/org/apache/calcite/sql/dialect/PostgresqlSqlDialect.java b/core/src/main/java/org/apache/calcite/sql/dialect/PostgresqlSqlDialect.java index 87137c82989..db225546217 100644 --- a/core/src/main/java/org/apache/calcite/sql/dialect/PostgresqlSqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/dialect/PostgresqlSqlDialect.java @@ -137,4 +137,8 @@ public PostgresqlSqlDialect(Context context) { super.unparseCall(writer, call, leftPrec, rightPrec); } } + + @Override public boolean supportsGroupByLiteral() { + return false; + } } diff --git a/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java b/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java index 0d498de0d80..4e94977ce30 100644 --- a/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java +++ b/core/src/main/java/org/apache/calcite/sql/dialect/RedshiftSqlDialect.java @@ -106,4 +106,12 @@ public RedshiftSqlDialect(Context context) { new SqlUserDefinedTypeNameSpec(castSpec, SqlParserPos.ZERO), SqlParserPos.ZERO); } + + @Override public boolean supportsGroupByLiteral() { + return false; + } + + @Override public boolean supportsAliasedValues() { + return false; + } } diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java index 7d0742c6d2a..8bc7c82f4fe 100644 --- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java @@ -219,6 +219,30 @@ private static String toSql(RelNode root, SqlDialect dialect, .getSql(); } + @Test void testGroupByBooleanLiteral() { + String query = "select avg(\"salary\") from \"employee\" group by true"; + String expectedRedshift = "SELECT AVG(\"employee\".\"salary\")\n" + + "FROM \"foodmart\".\"employee\",\n" + + "(SELECT TRUE AS \"$f0\") AS \"t\"\nGROUP BY \"t\".\"$f0\""; + String expectedInformix = "SELECT AVG(employee.salary)\nFROM foodmart.employee," + + "\n(SELECT TRUE AS $f0) AS t\nGROUP BY t.$f0"; + sql(query) + .withRedshift().ok(expectedRedshift) + .withInformix().ok(expectedInformix); + } + + @Test void testGroupByDateLiteral() { + String query = "select avg(\"salary\") from \"employee\" group by DATE '2022-01-01'"; + String expectedRedshift = "SELECT AVG(\"employee\".\"salary\")\n" + + "FROM \"foodmart\".\"employee\",\n" + + "(SELECT DATE '2022-01-01' AS \"$f0\") AS \"t\"\nGROUP BY \"t\".\"$f0\""; + String expectedInformix = "SELECT AVG(employee.salary)\nFROM foodmart.employee," + + "\n(SELECT DATE '2022-01-01' AS $f0) AS t\nGROUP BY t.$f0"; + sql(query) + .withRedshift().ok(expectedRedshift) + .withInformix().ok(expectedInformix); + } + @Test void testSimpleSelectStarFromProductTable() { String query = "select * from \"product\""; String expected = "SELECT *\n" @@ -5083,7 +5107,9 @@ private void checkLiteral2(String expression, String expected) { + "UNION ALL\n" + "SELECT 2 AS a, 'yy' AS b)"; final String expectedSnowflake = expectedPostgresql; - final String expectedRedshift = expectedPostgresql; + final String expectedRedshift = "SELECT \"a\"\n" + + "FROM (SELECT 1 AS \"a\", 'x ' AS \"b\"\n" + + "UNION ALL\nSELECT 2 AS \"a\", 'yy' AS \"b\")"; sql(sql) .withClickHouse().ok(expectedClickHouse) .withBigQuery().ok(expectedBigQuery) @@ -6358,6 +6384,10 @@ Sql withRedshift() { return dialect(DatabaseProduct.REDSHIFT.getDialect()); } + Sql withInformix() { + return dialect(DatabaseProduct.INFORMIX.getDialect()); + } + Sql withSnowflake() { return dialect(DatabaseProduct.SNOWFLAKE.getDialect()); } diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index 219c3d91056..a6bddd14529 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -63,6 +63,7 @@ import org.apache.calcite.rel.logical.LogicalTableModify; import org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule; import org.apache.calcite.rel.rules.AggregateExtractProjectRule; +import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule; import org.apache.calcite.rel.rules.AggregateProjectMergeRule; import org.apache.calcite.rel.rules.AggregateProjectPullUpConstantsRule; import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; @@ -200,6 +201,33 @@ private static boolean skipItem(RexNode expr) { && "item".equalsIgnoreCase(((RexCall) expr).getOperator().getName()); } + @Test void testGroupByDateLiteralSimple() { + final String query = "select avg(sal)\n" + + "from emp\n" + + "group by DATE '2022-01-01'"; + sql(query) + .withRule(AggregateProjectConstantToDummyJoinRule.Config.DEFAULT.toRule()) + .check(); + } + + @Test void testGroupByBooleanLiteralSimple() { + final String query = "select avg(sal)\n" + + "from emp\n" + + "group by true"; + sql(query) + .withRule(AggregateProjectConstantToDummyJoinRule.Config.DEFAULT.toRule()) + .check(); + } + + @Test void testGroupByMultipleLiterals() { + final String query = "select avg(sal)\n" + + "from emp\n" + + "group by false, deptno, true, true, empno, false, 'ab', DATE '2022-01-01'"; + sql(query) + .withRule(AggregateProjectConstantToDummyJoinRule.Config.DEFAULT.toRule()) + .check(); + } + @Test void testReduceNot() { HepProgramBuilder builder = new HepProgramBuilder(); builder.addRuleClass(ReduceExpressionsRule.class); diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index b5ac46dca5e..1614355031f 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -4033,6 +4033,81 @@ LogicalProject(EXPR$0=[1]) LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) LogicalFilter(condition=[>($5, 100)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/spark/src/main/java/org/apache/calcite/adapter/spark/JdbcToSparkConverter.java b/spark/src/main/java/org/apache/calcite/adapter/spark/JdbcToSparkConverter.java index 9f371243ad9..557d4ec739d 100644 --- a/spark/src/main/java/org/apache/calcite/adapter/spark/JdbcToSparkConverter.java +++ b/spark/src/main/java/org/apache/calcite/adapter/spark/JdbcToSparkConverter.java @@ -115,7 +115,7 @@ private String generateSql(SqlDialect dialect) { new JdbcImplementor(dialect, (JavaTypeFactory) getCluster().getTypeFactory()); final JdbcImplementor.Result result = - jdbcImplementor.visitInput(this, 0); + jdbcImplementor.visitRoot(this.getInput()); return result.asStatement().toSqlString(dialect).getSql(); } }