From bb4a0638260b6103994b14e1071eb679ae4bb5c2 Mon Sep 17 00:00:00 2001 From: snuyanzin Date: Tue, 25 Dec 2018 19:51:03 +0300 Subject: [PATCH] [CALCITE-2754] Implement LISTAGG function (Sergey Nuyanzin, Chunwei Lei) Close #1142 --- .../adapter/enumerable/RexImpTable.java | 29 ++++++++++ .../java/org/apache/calcite/sql/SqlKind.java | 5 +- .../calcite/sql/fun/SqlStdOperatorTable.java | 14 +++++ .../calcite/sql/validate/AggVisitor.java | 5 ++ .../calcite/sql/test/SqlOperatorBaseTest.java | 21 +++++++ core/src/test/resources/sql/agg.iq | 56 +++++++++++++++++++ site/_docs/reference.md | 2 +- 7 files changed, 130 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java index fd535cae661..157e10df703 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java @@ -184,6 +184,7 @@ import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LESS_THAN; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LESS_THAN_OR_EQUAL; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LIKE; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LISTAGG; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LN; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LOCALTIME; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LOCALTIMESTAMP; @@ -257,6 +258,8 @@ public class RexImpTable { Expressions.constant(false); public static final ConstantExpression TRUE_EXPR = Expressions.constant(true); + public static final ConstantExpression COMMA_EXPR = + Expressions.constant(","); public static final MemberExpression BOXED_FALSE_EXPR = Expressions.field(null, Boolean.class, "FALSE"); public static final MemberExpression BOXED_TRUE_EXPR = @@ -535,6 +538,7 @@ public Expression implement(RexToLixTranslator translator, aggMap.put(BIT_OR, bitop); aggMap.put(SINGLE_VALUE, constructorSupplier(SingleValueImplementor.class)); aggMap.put(COLLECT, constructorSupplier(CollectImplementor.class)); + aggMap.put(LISTAGG, constructorSupplier(ListaggImplementor.class)); aggMap.put(FUSION, constructorSupplier(FusionImplementor.class)); final Supplier grouping = constructorSupplier(GroupingImplementor.class); @@ -1370,6 +1374,31 @@ static class CollectImplementor extends StrictAggImplementor { } } + /** Implementor for the {@code LISTAGG} aggregate function. */ + static class ListaggImplementor extends StrictAggImplementor { + @Override protected void implementNotNullReset(AggContext info, + AggResetContext reset) { + reset.currentBlock().add( + Expressions.statement( + Expressions.assign(reset.accumulator().get(0), NULL_EXPR))); + } + + @Override public void implementNotNullAdd(AggContext info, + AggAddContext add) { + final Expression accValue = add.accumulator().get(0); + final Expression arg0 = add.arguments().get(0); + final Expression arg1 = add.arguments().size() == 2 + ? add.arguments().get(1) : COMMA_EXPR; + final Expression result = Expressions.condition( + Expressions.equal(NULL_EXPR, accValue), + arg0, + Expressions.call(BuiltInMethod.STRING_CONCAT.method, accValue, + Expressions.call(BuiltInMethod.STRING_CONCAT.method, arg1, arg0))); + + add.currentBlock().add(Expressions.statement(Expressions.assign(accValue, result))); + } + } + /** Implementor for the {@code FUSION} aggregate function. */ static class FusionImplementor extends StrictAggImplementor { @Override protected void implementNotNullReset(AggContext info, diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java index eeda7170af8..9b836b99749 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java @@ -931,6 +931,9 @@ public enum SqlKind { /** The {@code NTH_VALUE} aggregate function. */ NTH_VALUE, + /** The {@code LISTAGG} aggregate function. */ + LISTAGG, + /** The {@code COLLECT} aggregate function. */ COLLECT, @@ -1125,7 +1128,7 @@ public enum SqlKind { LAST_VALUE, COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY, AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP, NTILE, COLLECT, FUSION, SINGLE_VALUE, ROW_NUMBER, RANK, PERCENT_RANK, DENSE_RANK, - CUME_DIST, JSON_ARRAYAGG, JSON_OBJECTAGG, BIT_AND, BIT_OR); + CUME_DIST, JSON_ARRAYAGG, JSON_OBJECTAGG, BIT_AND, BIT_OR, LISTAGG); /** * Category consisting of all DML operators. diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java index 0e9d563441e..8654f7d2654 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java @@ -2158,6 +2158,20 @@ public boolean argumentMustBeScalar(int ordinal) { Optionality.OPTIONAL) { }; + /** + * The LISTAGG operator. Multiset aggregator function. + */ + public static final SqlAggFunction LISTAGG = + new SqlAggFunction("LISTAGG", + null, + SqlKind.LISTAGG, + ReturnTypes.ARG0_NULLABLE, + null, + OperandTypes.or(OperandTypes.STRING, OperandTypes.STRING_STRING), + SqlFunctionCategory.SYSTEM, false, false, + Optionality.OPTIONAL) { + }; + /** * The FUSION operator. Multiset aggregator function. */ diff --git a/core/src/main/java/org/apache/calcite/sql/validate/AggVisitor.java b/core/src/main/java/org/apache/calcite/sql/validate/AggVisitor.java index b6f2c34a6a6..1a78785e8a9 100644 --- a/core/src/main/java/org/apache/calcite/sql/validate/AggVisitor.java +++ b/core/src/main/java/org/apache/calcite/sql/validate/AggVisitor.java @@ -105,6 +105,11 @@ public Void visit(SqlCall call) { // don't traverse into queries return null; } + if (call.getKind() == SqlKind.WITHIN_GROUP) { + if (aggregate) { + return found(call); + } + } if (call.getKind() == SqlKind.OVER) { if (over) { return found(call); diff --git a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java index 4f49fb83443..4c77ddee78e 100644 --- a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java +++ b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java @@ -6195,6 +6195,27 @@ protected static Pair currentTimeString(TimeZone tz) { tester.checkAgg("collect(DISTINCT x)", values, 2, (double) 0); } + @Test public void testListaggFunc() { + tester.setFor(SqlStdOperatorTable.LISTAGG, VM_FENNEL, VM_JAVA); + tester.checkFails("listagg(^*^)", "Unknown identifier '\\*'", false); + tester.checkFails("^listagg(12)^", + "Cannot apply 'LISTAGG' to arguments of type .*'\n.*'", false); + tester.checkFails("^listagg(cast(12 as double))^", + "Cannot apply 'LISTAGG' to arguments of type .*'\n.*'", false); + tester.checkFails("^listagg()^", + "Invalid number of arguments to function 'LISTAGG'. Was expecting 1 arguments", + false); + tester.checkFails("^listagg('1', '2', '3')^", + "Invalid number of arguments to function 'LISTAGG'. Was expecting 1 arguments", + false); + checkAggType(tester, "listagg('test')", "CHAR(4) NOT NULL"); + checkAggType(tester, "listagg('test', ', ')", "CHAR(4) NOT NULL"); + final String[] values1 = {"'hello'", "CAST(null AS CHAR)", "'world'", "'!'"}; + tester.checkAgg("listagg(x)", values1, "hello,world,!", (double) 0); + final String[] values2 = {"0", "1", "2", "3"}; + tester.checkAgg("listagg(cast(x as CHAR))", values2, "0,1,2,3", (double) 0); + } + @Test public void testFusionFunc() { tester.setFor(SqlStdOperatorTable.FUSION, VM_FENNEL, VM_JAVA); } diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq index 41c1e9541aa..c3fda247f0f 100644 --- a/core/src/test/resources/sql/agg.iq +++ b/core/src/test/resources/sql/agg.iq @@ -2699,4 +2699,60 @@ EnumerableAggregate(group=[{0}], EXPR$1=[JSON_OBJECTAGG_NULL_ON_NULL($1, $2)], E EnumerableValues(tuples=[[{ 0 }]]) !plan +select listagg(ename) as combined_name from emp; ++------------------------------------------------+ +| COMBINED_NAME | ++------------------------------------------------+ +| Jane,Bob,Eric,Susan,Alice,Adam,Eve,Grace,Wilma | ++------------------------------------------------+ +(1 row) + +!ok + +select listagg(ename) within group(order by gender, ename) as combined_name from emp; ++------------------------------------------------+ +| COMBINED_NAME | ++------------------------------------------------+ +| Alice,Eve,Grace,Jane,Susan,Wilma,Adam,Bob,Eric | ++------------------------------------------------+ +(1 row) + +!ok + +EnumerableAggregate(group=[{}], COMBINED_NAME=[LISTAGG($0) WITHIN GROUP ([1, 0])]) + EnumerableUnion(all=[true]) + EnumerableCalc(expr#0=[{inputs}], expr#1=['Jane'], expr#2=['F'], EXPR$0=[$t1], EXPR$2=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) + EnumerableCalc(expr#0=[{inputs}], expr#1=['Bob'], expr#2=['M'], EXPR$0=[$t1], EXPR$2=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) + EnumerableCalc(expr#0=[{inputs}], expr#1=['Eric'], expr#2=['M'], EXPR$0=[$t1], EXPR$2=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) + EnumerableCalc(expr#0=[{inputs}], expr#1=['Susan'], expr#2=['F'], EXPR$0=[$t1], EXPR$2=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) + EnumerableCalc(expr#0=[{inputs}], expr#1=['Alice'], expr#2=['F'], EXPR$0=[$t1], EXPR$2=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) + EnumerableCalc(expr#0=[{inputs}], expr#1=['Adam'], expr#2=['M'], EXPR$0=[$t1], EXPR$2=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) + EnumerableCalc(expr#0=[{inputs}], expr#1=['Eve'], expr#2=['F'], EXPR$0=[$t1], EXPR$2=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) + EnumerableCalc(expr#0=[{inputs}], expr#1=['Grace'], expr#2=['F'], EXPR$0=[$t1], EXPR$2=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) + EnumerableCalc(expr#0=[{inputs}], expr#1=['Wilma'], expr#2=['F'], EXPR$0=[$t1], EXPR$2=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) +!plan + +select + listagg(ename) within group(order by deptno, ename) as default_listagg_sep, + listagg(ename, '; ') within group(order by deptno, ename desc) as custom_listagg_sep +from emp group by gender; ++----------------------------------+---------------------------------------+ +| DEFAULT_LISTAGG_SEP | CUSTOM_LISTAGG_SEP | ++----------------------------------+---------------------------------------+ +| Bob,Eric,Adam | Bob; Eric; Adam | +| Jane,Alice,Susan,Eve,Grace,Wilma | Jane; Susan; Alice; Eve; Grace; Wilma | ++----------------------------------+---------------------------------------+ +(2 rows) + +!ok + # End agg.iq diff --git a/site/_docs/reference.md b/site/_docs/reference.md index 62a2108600f..5868e8967f2 100644 --- a/site/_docs/reference.md +++ b/site/_docs/reference.md @@ -1535,6 +1535,7 @@ and `LISTAGG`). | Operator syntax | Description |:---------------------------------- |:----------- | COLLECT( [ ALL | DISTINCT ] value) | Returns a multiset of the values +| LISTAGG( [ ALL | DISTINCT ] value [, separator]) | Returns values concatenated into a string, delimited by separator (default ',') | COUNT( [ ALL | DISTINCT ] value [, value ]*) | Returns the number of input rows for which *value* is not null (wholly not null if *value* is composite) | COUNT(*) | Returns the number of input rows | FUSION(multiset) | Returns the multiset union of *multiset* across all input values @@ -1558,7 +1559,6 @@ and `LISTAGG`). Not implemented: -* LISTAGG(string) * REGR_AVGX(numeric1, numeric2) * REGR_AVGY(numeric1, numeric2) * REGR_INTERCEPT(numeric1, numeric2)