Skip to content

Commit e545811

Browse files
dilipbiswalcloud-fan
authored andcommitted
[SPARK-19851][SQL] Add support for EVERY and ANY (SOME) aggregates
## What changes were proposed in this pull request? Implements Every, Some, Any aggregates in SQL. These new aggregate expressions are analyzed in normal way and rewritten to equivalent existing aggregate expressions in the optimizer. Every(x) => Min(x) where x is boolean. Some(x) => Max(x) where x is boolean. Any is a synonym for Some. SQL ``` explain extended select every(v) from test_agg group by k; ``` Plan : ``` == Parsed Logical Plan == 'Aggregate ['k], [unresolvedalias('every('v), None)] +- 'UnresolvedRelation `test_agg` == Analyzed Logical Plan == every(v): boolean Aggregate [k#0], [every(v#1) AS every(v)#5] +- SubqueryAlias `test_agg` +- Project [k#0, v#1] +- SubqueryAlias `test_agg` +- LocalRelation [k#0, v#1] == Optimized Logical Plan == Aggregate [k#0], [min(v#1) AS every(v)#5] +- LocalRelation [k#0, v#1] == Physical Plan == *(2) HashAggregate(keys=[k#0], functions=[min(v#1)], output=[every(v)#5]) +- Exchange hashpartitioning(k#0, 200) +- *(1) HashAggregate(keys=[k#0], functions=[partial_min(v#1)], output=[k#0, min#7]) +- LocalTableScan [k#0, v#1] Time taken: 0.512 seconds, Fetched 1 row(s) ``` ## How was this patch tested? Added tests in SQLQueryTestSuite, DataframeAggregateSuite Closes #22809 from dilipbiswal/SPARK-19851-specific-rewrite. Authored-by: Dilip Biswal <dbiswal@us.ibm.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 41e1416 commit e545811

File tree

7 files changed

+388
-4
lines changed

7 files changed

+388
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,9 @@ object FunctionRegistry {
300300
expression[CollectList]("collect_list"),
301301
expression[CollectSet]("collect_set"),
302302
expression[CountMinSketchAgg]("count_min_sketch"),
303+
expression[EveryAgg]("every"),
304+
expression[AnyAgg]("any"),
305+
expression[SomeAgg]("some"),
303306

304307
// string functions
305308
expression[Ascii]("ascii"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.util.Locale
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
24+
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
2425
import org.apache.spark.sql.catalyst.expressions.codegen._
2526
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2627
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -282,6 +283,31 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable {
282283
override lazy val canonicalized: Expression = child.canonicalized
283284
}
284285

286+
/**
287+
* An aggregate expression that gets rewritten (currently by the optimizer) into a
288+
* different aggregate expression for evaluation. This is mainly used to provide compatibility
289+
* with other databases. For example, we use this to support every, any/some aggregates by rewriting
290+
* them with Min and Max respectively.
291+
*/
292+
trait UnevaluableAggregate extends DeclarativeAggregate {
293+
294+
override def nullable: Boolean = true
295+
296+
override lazy val aggBufferAttributes =
297+
throw new UnsupportedOperationException(s"Cannot evaluate aggBufferAttributes: $this")
298+
299+
override lazy val initialValues: Seq[Expression] =
300+
throw new UnsupportedOperationException(s"Cannot evaluate initialValues: $this")
301+
302+
override lazy val updateExpressions: Seq[Expression] =
303+
throw new UnsupportedOperationException(s"Cannot evaluate updateExpressions: $this")
304+
305+
override lazy val mergeExpressions: Seq[Expression] =
306+
throw new UnsupportedOperationException(s"Cannot evaluate mergeExpressions: $this")
307+
308+
override lazy val evaluateExpression: Expression =
309+
throw new UnsupportedOperationException(s"Cannot evaluate evaluateExpression: $this")
310+
}
285311

286312
/**
287313
* Expressions that don't have SQL representation should extend this trait. Examples are
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.aggregate
19+
20+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
21+
import org.apache.spark.sql.catalyst.expressions._
22+
import org.apache.spark.sql.types._
23+
24+
abstract class UnevaluableBooleanAggBase(arg: Expression)
25+
extends UnevaluableAggregate with ImplicitCastInputTypes {
26+
27+
override def children: Seq[Expression] = arg :: Nil
28+
29+
override def dataType: DataType = BooleanType
30+
31+
override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
32+
33+
override def checkInputDataTypes(): TypeCheckResult = {
34+
arg.dataType match {
35+
case dt if dt != BooleanType =>
36+
TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " +
37+
s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].")
38+
case _ => TypeCheckResult.TypeCheckSuccess
39+
}
40+
}
41+
}
42+
43+
@ExpressionDescription(
44+
usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.",
45+
since = "3.0.0")
46+
case class EveryAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
47+
override def nodeName: String = "Every"
48+
}
49+
50+
@ExpressionDescription(
51+
usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.",
52+
since = "3.0.0")
53+
case class AnyAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
54+
override def nodeName: String = "Any"
55+
}
56+
57+
@ExpressionDescription(
58+
usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.",
59+
since = "3.0.0")
60+
case class SomeAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) {
61+
override def nodeName: String = "Some"
62+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,32 @@ import scala.collection.mutable
2121

2222
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
2323
import org.apache.spark.sql.catalyst.expressions._
24+
import org.apache.spark.sql.catalyst.expressions.aggregate._
2425
import org.apache.spark.sql.catalyst.plans.logical._
2526
import org.apache.spark.sql.catalyst.rules._
2627
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2728
import org.apache.spark.sql.types._
2829

2930

3031
/**
31-
* Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can
32-
* be evaluated. This is mainly used to provide compatibility with other databases.
33-
* For example, we use this to support "nvl" by replacing it with "coalesce".
32+
* Finds all the expressions that are unevaluable and replaces/rewrites them with semantically
33+
* equivalent expressions that can be evaluated. Currently we replace two kinds of expressions:
34+
* 1) [[RuntimeReplaceable]] expressions
35+
* 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any
36+
* This is mainly used to provide compatibility with other databases.
37+
* A few examples are:
38+
* we use this to support "nvl" by replacing it with "coalesce".
39+
* we use this to replace Every and Any with Min and Max respectively.
40+
*
41+
* TODO: In future, explore an option to replace aggregate functions similar to
42+
* how RuntimeReplaceable does.
3443
*/
3544
object ReplaceExpressions extends Rule[LogicalPlan] {
3645
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
3746
case e: RuntimeReplaceable => e.child
47+
case SomeAgg(arg) => Max(arg)
48+
case AnyAgg(arg) => Max(arg)
49+
case EveryAgg(arg) => Min(arg)
3850
}
3951
}
4052

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
144144
assertSuccess(Sum('stringField))
145145
assertSuccess(Average('stringField))
146146
assertSuccess(Min('arrayField))
147+
assertSuccess(new EveryAgg('booleanField))
148+
assertSuccess(new AnyAgg('booleanField))
149+
assertSuccess(new SomeAgg('booleanField))
147150

148151
assertError(Min('mapField), "min does not support ordering on type")
149152
assertError(Max('mapField), "max does not support ordering on type")

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,69 @@ SELECT 1 FROM range(10) HAVING true;
8080
SELECT 1 FROM range(10) HAVING MAX(id) > 0;
8181

8282
SELECT id FROM range(10) HAVING id > 0;
83+
84+
-- Test data
85+
CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
86+
(1, true), (1, false),
87+
(2, true),
88+
(3, false), (3, null),
89+
(4, null), (4, null),
90+
(5, null), (5, true), (5, false) AS test_agg(k, v);
91+
92+
-- empty table
93+
SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0;
94+
95+
-- all null values
96+
SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4;
97+
98+
-- aggregates are null-filtering (null inputs are ignored)
99+
SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5;
100+
101+
-- group by
102+
SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k;
103+
104+
-- having
105+
SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false;
106+
SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL;
107+
108+
-- basic subquery path to make sure rewrite happens in both parent and child plans.
109+
SELECT k,
110+
Every(v) AS every
111+
FROM test_agg
112+
WHERE k = 2
113+
AND v IN (SELECT Any(v)
114+
FROM test_agg
115+
WHERE k = 1)
116+
GROUP BY k;
117+
118+
-- basic subquery path to make sure rewrite happens in both parent and child plans.
119+
SELECT k,
120+
Every(v) AS every
121+
FROM test_agg
122+
WHERE k = 2
123+
AND v IN (SELECT Every(v)
124+
FROM test_agg
125+
WHERE k = 1)
126+
GROUP BY k;
127+
128+
-- input type checking Int
129+
SELECT every(1);
130+
131+
-- input type checking Short
132+
SELECT some(1S);
133+
134+
-- input type checking Long
135+
SELECT any(1L);
136+
137+
-- input type checking String
138+
SELECT every("true");
139+
140+
-- every/some/any aggregates are supported as window expressions.
141+
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
142+
SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
143+
SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
144+
145+
-- simple explain of queries having every/some/any aggregates. Optimized
146+
-- plan should show the rewritten aggregate expression.
147+
EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k;
148+

0 commit comments

Comments
 (0)