Skip to content

Commit

Permalink
[CALCITE-4702] Error when executing query with GROUP BY constant via …
Browse files Browse the repository at this point in the history
…JDBC adapter

Add new method in SqlDialect controlling whether GROUP BY using
literals is supported. Note that the whole Postgres family returns
false by precaution; some literals may be supported by Postgres or some
derivation of it. We agreed the extra complexity needed handle those
special cases was not worth it so we decided to return false for all
kinds of literals.

Introduce a new rule to rewrite the GROUP BY using an inner join with a
dummy table for those dialects that do not support literals.

Add a rule based transformation step at the beginning of the rel to SQL
conversion and ensure callers are passing from there. This allows to
keep the aggregate constant transformation in a single place.

Add tests with GROUP BY and different types of literals.

Close apache#2482
  • Loading branch information
soumyakanti3578 authored and liyafan82 committed Mar 4, 2022
1 parent bb89b92 commit 466fb42
Show file tree
Hide file tree
Showing 11 changed files with 334 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ private SqlString generateSql(SqlDialect dialect) {
new JdbcImplementor(dialect,
(JavaTypeFactory) getCluster().getTypeFactory());
final JdbcImplementor.Result result =
jdbcImplementor.visitInput(this, 0);
jdbcImplementor.visitRoot(this.getInput());
return result.asStatement().toSqlString(dialect);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
Expand All @@ -29,6 +31,7 @@
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Window;
import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
Expand Down Expand Up @@ -161,8 +164,20 @@ protected SqlImplementor(SqlDialect dialect) {

/** Visits a relational expression that has no parent. */
public final Result visitRoot(RelNode r) {
RelNode best;
if (!this.dialect.supportsGroupByLiteral()) {
HepProgramBuilder hepProgramBuilder = new HepProgramBuilder();
hepProgramBuilder.addRuleInstance(
AggregateProjectConstantToDummyJoinRule.Config.DEFAULT.toRule());
HepPlanner hepPlanner = new HepPlanner(hepProgramBuilder.build());

hepPlanner.setRoot(r);
best = hepPlanner.findBestExp();
} else {
best = r;
}
try {
return visitInput(holder(r), 0);
return visitInput(holder(best), 0);
} catch (Error | RuntimeException e) {
throw Util.throwAsRuntime("Error while converting RelNode to SqlNode:\n"
+ RelOptUtil.toString(r), e);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;

import com.google.common.collect.ImmutableList;

import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.List;

/**
* Planner rule that recognizes a {@link org.apache.calcite.rel.core.Aggregate}
* on top of a {@link org.apache.calcite.rel.core.Project} where the aggregate's group set
* contains literals (true, false, DATE, chars, etc), and removes the literals from the
* group keys by joining with a dummy table of literals.
*
* <pre>{@code
* select avg(sal)
* from emp
* group by true, DATE '2022-01-01';
* }</pre>
* becomes
* <pre>{@code
* select avg(sal)
* from emp, (select true x, DATE '2022-01-01' d) dummy
* group by dummy.x, dummy.d;
* }</pre>
*/
@Value.Enclosing
public final class AggregateProjectConstantToDummyJoinRule
extends RelRule<AggregateProjectConstantToDummyJoinRule.Config> {

/** Creates an AggregateProjectConstantToDummyJoinRule. */
private AggregateProjectConstantToDummyJoinRule(Config config) {
super(config);
}

@Override public boolean matches(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
final Project project = call.rel(1);

for (int groupKey: aggregate.getGroupSet().asList()) {
if (groupKey >= aggregate.getRowType().getFieldCount()) {
continue;
}
RexNode groupKeyProject = project.getProjects().get(groupKey);
if (groupKeyProject instanceof RexLiteral) {
return true;
}
}

return false;
}

@Override public void onMatch(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
final Project project = call.rel(1);

RelBuilder builder = call.builder();
RexBuilder rexBuilder = builder.getRexBuilder();

builder.push(project.getInput());
int offset = project.getInput().getRowType().getFieldCount();

RelDataTypeFactory.Builder valuesType = rexBuilder.getTypeFactory().builder();
List<RexLiteral> literals = new ArrayList<>();
List<RexNode> projects = project.getProjects();
for (int i = 0; i < projects.size(); i++) {
RexNode node = projects.get(i);
if (node instanceof RexLiteral) {
literals.add((RexLiteral) node);
valuesType.add(project.getRowType().getFieldList().get(i));
}
}
builder.values(ImmutableList.of(literals), valuesType.build());

builder.join(JoinRelType.INNER, rexBuilder.makeLiteral(true));

List<RexNode> newProjects = new ArrayList<>();
int literalCounter = 0;
for (RexNode exp : project.getProjects()) {
if (exp instanceof RexLiteral) {
newProjects.add(builder.field(offset + literalCounter++));
} else {
newProjects.add(exp);
}
}

builder.project(newProjects);
builder.aggregate(
builder.groupKey(
aggregate.getGroupSet(), aggregate.getGroupSets()), aggregate.getAggCallList()
);

call.transformTo(builder.build());
}

/** Rule configuration. */
@Value.Immutable
public interface Config extends RelRule.Config {
Config DEFAULT = ImmutableAggregateProjectConstantToDummyJoinRule.Config.of()
.withOperandFor(Aggregate.class, Project.class);

@Override default AggregateProjectConstantToDummyJoinRule toRule() {
return new AggregateProjectConstantToDummyJoinRule(this);
}

default Config withOperandFor(Class<? extends Aggregate> aggregateClass,
Class<? extends Project> projectClass) {
return withOperandSupplier(b0 ->
b0.operand(aggregateClass).oneInput(b1 ->
b1.operand(projectClass).anyInputs()))
.as(Config.class);
}
}
}
20 changes: 20 additions & 0 deletions core/src/main/java/org/apache/calcite/sql/SqlDialect.java
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,26 @@ public boolean supportsCharSet() {
return true;
}

/**
* Returns whether the dialect supports GROUP BY literals.
*
* <p>For instance, in {@link DatabaseProduct#REDSHIFT}, the following queries are illegal.</p>
* <pre>{@code
* select avg(salary)
* from emp
* group by true
* }</pre>
*
* <pre>{@code
* select avg(salary)
* from emp
* group by 'a', DATE '2022-01-01'
* }</pre>
*/
public boolean supportsGroupByLiteral() {
return true;
}

public boolean supportsAggregateFunction(SqlKind kind) {
switch (kind) {
case COUNT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,12 @@ public class InformixSqlDialect extends SqlDialect {
public InformixSqlDialect(Context context) {
super(context);
}

@Override public boolean supportsGroupByLiteral() {
return false;
}

@Override public boolean supportsAliasedValues() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,8 @@ public PostgresqlSqlDialect(Context context) {
super.unparseCall(writer, call, leftPrec, rightPrec);
}
}

@Override public boolean supportsGroupByLiteral() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,12 @@ public RedshiftSqlDialect(Context context) {
new SqlUserDefinedTypeNameSpec(castSpec, SqlParserPos.ZERO),
SqlParserPos.ZERO);
}

@Override public boolean supportsGroupByLiteral() {
return false;
}

@Override public boolean supportsAliasedValues() {
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,30 @@ private static String toSql(RelNode root, SqlDialect dialect,
.getSql();
}

@Test void testGroupByBooleanLiteral() {
String query = "select avg(\"salary\") from \"employee\" group by true";
String expectedRedshift = "SELECT AVG(\"employee\".\"salary\")\n"
+ "FROM \"foodmart\".\"employee\",\n"
+ "(SELECT TRUE AS \"$f0\") AS \"t\"\nGROUP BY \"t\".\"$f0\"";
String expectedInformix = "SELECT AVG(employee.salary)\nFROM foodmart.employee,"
+ "\n(SELECT TRUE AS $f0) AS t\nGROUP BY t.$f0";
sql(query)
.withRedshift().ok(expectedRedshift)
.withInformix().ok(expectedInformix);
}

@Test void testGroupByDateLiteral() {
String query = "select avg(\"salary\") from \"employee\" group by DATE '2022-01-01'";
String expectedRedshift = "SELECT AVG(\"employee\".\"salary\")\n"
+ "FROM \"foodmart\".\"employee\",\n"
+ "(SELECT DATE '2022-01-01' AS \"$f0\") AS \"t\"\nGROUP BY \"t\".\"$f0\"";
String expectedInformix = "SELECT AVG(employee.salary)\nFROM foodmart.employee,"
+ "\n(SELECT DATE '2022-01-01' AS $f0) AS t\nGROUP BY t.$f0";
sql(query)
.withRedshift().ok(expectedRedshift)
.withInformix().ok(expectedInformix);
}

@Test void testSimpleSelectStarFromProductTable() {
String query = "select * from \"product\"";
String expected = "SELECT *\n"
Expand Down Expand Up @@ -5083,7 +5107,9 @@ private void checkLiteral2(String expression, String expected) {
+ "UNION ALL\n"
+ "SELECT 2 AS a, 'yy' AS b)";
final String expectedSnowflake = expectedPostgresql;
final String expectedRedshift = expectedPostgresql;
final String expectedRedshift = "SELECT \"a\"\n"
+ "FROM (SELECT 1 AS \"a\", 'x ' AS \"b\"\n"
+ "UNION ALL\nSELECT 2 AS \"a\", 'yy' AS \"b\")";
sql(sql)
.withClickHouse().ok(expectedClickHouse)
.withBigQuery().ok(expectedBigQuery)
Expand Down Expand Up @@ -6358,6 +6384,10 @@ Sql withRedshift() {
return dialect(DatabaseProduct.REDSHIFT.getDialect());
}

Sql withInformix() {
return dialect(DatabaseProduct.INFORMIX.getDialect());
}

Sql withSnowflake() {
return dialect(DatabaseProduct.SNOWFLAKE.getDialect());
}
Expand Down
28 changes: 28 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.apache.calcite.rel.logical.LogicalTableModify;
import org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule;
import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule;
import org.apache.calcite.rel.rules.AggregateProjectMergeRule;
import org.apache.calcite.rel.rules.AggregateProjectPullUpConstantsRule;
import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule;
Expand Down Expand Up @@ -200,6 +201,33 @@ private static boolean skipItem(RexNode expr) {
&& "item".equalsIgnoreCase(((RexCall) expr).getOperator().getName());
}

@Test void testGroupByDateLiteralSimple() {
final String query = "select avg(sal)\n"
+ "from emp\n"
+ "group by DATE '2022-01-01'";
sql(query)
.withRule(AggregateProjectConstantToDummyJoinRule.Config.DEFAULT.toRule())
.check();
}

@Test void testGroupByBooleanLiteralSimple() {
final String query = "select avg(sal)\n"
+ "from emp\n"
+ "group by true";
sql(query)
.withRule(AggregateProjectConstantToDummyJoinRule.Config.DEFAULT.toRule())
.check();
}

@Test void testGroupByMultipleLiterals() {
final String query = "select avg(sal)\n"
+ "from emp\n"
+ "group by false, deptno, true, true, empno, false, 'ab', DATE '2022-01-01'";
sql(query)
.withRule(AggregateProjectConstantToDummyJoinRule.Config.DEFAULT.toRule())
.check();
}

@Test void testReduceNot() {
HepProgramBuilder builder = new HepProgramBuilder();
builder.addRuleClass(ReduceExpressionsRule.class);
Expand Down
Loading

0 comments on commit 466fb42

Please sign in to comment.