Skip to content

Commit b793d06

Browse files
committed
Code changes
1 parent e6c5c84 commit b793d06

File tree

6 files changed

+101
-1
lines changed

6 files changed

+101
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,10 @@ object FunctionRegistry {
300300
expression[CollectList]("collect_list"),
301301
expression[CollectSet]("collect_set"),
302302
expression[CountMinSketchAgg]("count_min_sketch"),
303+
expression[EveryAgg]("every"),
304+
expression[AnyAgg]("any"),
305+
expression[SomeAgg]("some"),
306+
303307

304308
// string functions
305309
expression[Ascii]("ascii"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.util.Locale
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
24+
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
2425
import org.apache.spark.sql.catalyst.expressions.codegen._
2526
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2627
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -282,6 +283,31 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable {
282283
override lazy val canonicalized: Expression = child.canonicalized
283284
}
284285

286+
/**
287+
* An aggregate expression that gets rewritten (currently by the optimizer) into a
288+
* different aggregate expression for evaluation. This is mainly used to provide compatibility
289+
* with other databases. For example, we use this to support every, any/some aggregates by rewriting
290+
* them with Min and Max respectively.
291+
*/
292+
trait UnevaluableAggrgate extends DeclarativeAggregate {
293+
294+
override def nullable: Boolean = true
295+
296+
override lazy val aggBufferAttributes =
297+
throw new UnsupportedOperationException(s"Cannot evaluate aggBufferAttributes: $this")
298+
299+
override lazy val initialValues: Seq[Expression] =
300+
throw new UnsupportedOperationException(s"Cannot evaluate initialValues: $this")
301+
302+
override lazy val updateExpressions: Seq[Expression] =
303+
throw new UnsupportedOperationException(s"Cannot evaluate updateExpressions: $this")
304+
305+
override lazy val mergeExpressions: Seq[Expression] =
306+
throw new UnsupportedOperationException(s"Cannot evaluate mergeExpressions: $this")
307+
308+
override lazy val evaluateExpression: Expression =
309+
throw new UnsupportedOperationException(s"Cannot evaluate evaluateExpression: $this")
310+
}
285311

286312
/**
287313
* Expressions that don't have SQL representation should extend this trait. Examples are

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,34 @@ case class Max(child: Expression) extends DeclarativeAggregate {
5757

5858
override lazy val evaluateExpression: AttributeReference = max
5959
}
60+
61+
abstract class AnyAggBase(arg: Expression)
62+
extends UnevaluableAggrgate with ImplicitCastInputTypes {
63+
64+
override def children: Seq[Expression] = arg :: Nil
65+
66+
override def dataType: DataType = BooleanType
67+
68+
override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
69+
70+
override def checkInputDataTypes(): TypeCheckResult = {
71+
arg.dataType match {
72+
case dt if dt != BooleanType =>
73+
TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " +
74+
s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].")
75+
case _ => TypeCheckResult.TypeCheckSuccess
76+
}
77+
}
78+
}
79+
80+
@ExpressionDescription(
81+
usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.")
82+
case class AnyAgg(arg: Expression) extends AnyAggBase(arg) {
83+
override def nodeName: String = "Any"
84+
}
85+
86+
@ExpressionDescription(
87+
usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.")
88+
case class SomeAgg(arg: Expression) extends AnyAggBase(arg) {
89+
override def nodeName: String = "Some"
90+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,27 @@ case class Min(child: Expression) extends DeclarativeAggregate {
5757

5858
override lazy val evaluateExpression: AttributeReference = min
5959
}
60+
61+
@ExpressionDescription(
62+
usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.")
63+
case class EveryAgg(arg: Expression)
64+
extends UnevaluableAggrgate with ImplicitCastInputTypes {
65+
66+
override def nodeName: String = "Every"
67+
68+
override def children: Seq[Expression] = arg :: Nil
69+
70+
override def dataType: DataType = BooleanType
71+
72+
override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
73+
74+
override def checkInputDataTypes(): TypeCheckResult = {
75+
arg.dataType match {
76+
case dt if dt != BooleanType =>
77+
TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " +
78+
s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].")
79+
case _ => TypeCheckResult.TypeCheckSuccess
80+
}
81+
}
82+
}
83+

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
118118
ReplaceExpressions,
119119
ComputeCurrentTime,
120120
GetCurrentDatabase(sessionCatalog),
121+
RewriteUnevaluableAggregates,
121122
RewriteDistinctAggregates,
122123
ReplaceDeduplicateWithAggregate) ::
123124
//////////////////////////////////////////////////////////////////////////////////////////
@@ -206,7 +207,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
206207
PullupCorrelatedPredicates.ruleName ::
207208
RewriteCorrelatedScalarSubquery.ruleName ::
208209
RewritePredicateSubquery.ruleName ::
209-
PullOutPythonUDFInJoinCondition.ruleName :: Nil
210+
PullOutPythonUDFInJoinCondition.ruleName ::
211+
RewriteUnevaluableAggregates.ruleName :: Nil
210212

211213
/**
212214
* Optimize all the subqueries inside expression.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import scala.collection.mutable
2121

2222
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
2323
import org.apache.spark.sql.catalyst.expressions._
24+
import org.apache.spark.sql.catalyst.expressions.aggregate._
2425
import org.apache.spark.sql.catalyst.plans.logical._
2526
import org.apache.spark.sql.catalyst.rules._
2627
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -38,6 +39,18 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
3839
}
3940
}
4041

42+
/**
43+
* Rewrites the aggregates expressions by replacing them with another. This is mainly used to
44+
* provide compatibiity with other databases. For example, we use this to support
45+
* Every, Any/Some by rewriting them to Min, Max respectively.
46+
*/
47+
object RewriteUnevaluableAggregates extends Rule[LogicalPlan] {
48+
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
49+
case SomeAgg(arg) => Max(arg)
50+
case AnyAgg(arg) => Max(arg)
51+
case EveryAgg(arg) => Min(arg)
52+
}
53+
}
4154

4255
/**
4356
* Computes the current date and time to make sure we return the same result in a single query.

0 commit comments

Comments
 (0)