Skip to content

Commit 22e833d

Browse files
committed
Merge commit '9af338cd685bce26abbc2dd4d077bde5068157b1' into SPARK-34079-multi-column-scalar-subquery
# Conflicts:
#	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
2 parents: 0cff7b2 + 9af338c — commit 22e833d

File tree

14 files changed

+120
-30
lines changed

14 files changed

+120
-30
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvable
2727
import org.apache.spark.sql.catalyst.expressions.codegen._
2828
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2929
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
30+
import org.apache.spark.sql.catalyst.trees.TreePattern.{CAST, TreePattern}
3031
import org.apache.spark.sql.catalyst.util._
3132
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
3233
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -1800,6 +1801,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
18001801
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
18011802
copy(timeZoneId = Option(timeZoneId))
18021803

1804+
final override val nodePatterns: Seq[TreePattern] = Seq(CAST)
1805+
18031806
override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled
18041807

18051808
override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
2020
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT, TreePattern}
2324
import org.apache.spark.sql.internal.SQLConf
2425
import org.apache.spark.sql.types._
2526

@@ -48,6 +49,8 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
4849

4950
override def nullable: Boolean = false
5051

52+
final override val nodePatterns: Seq[TreePattern] = Seq(COUNT)
53+
5154
// Return data type.
5255
override def dataType: DataType = LongType
5356

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
2323
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
24+
import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern,
25+
UNARY_POSITIVE}
2426
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
2527
import org.apache.spark.sql.errors.QueryExecutionErrors
2628
import org.apache.spark.sql.internal.SQLConf
@@ -128,6 +130,8 @@ case class UnaryPositive(child: Expression)
128130

129131
override def dataType: DataType = child.dataType
130132

133+
final override val nodePatterns: Seq[TreePattern] = Seq(UNARY_POSITIVE)
134+
131135
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
132136
defineCodeGen(ctx, ev, c => c)
133137

@@ -199,6 +203,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
199203

200204
override def dataType: DataType = left.dataType
201205

206+
final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC)
207+
202208
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
203209

204210
/** Name of the function for this expression on a [[Decimal]] type. */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
2727
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
2828
import org.apache.spark.sql.catalyst.expressions.codegen._
2929
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
30+
import org.apache.spark.sql.catalyst.trees.TreePattern.{CONCAT, TreePattern}
3031
import org.apache.spark.sql.catalyst.util._
3132
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
3233
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -2172,6 +2173,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
21722173

21732174
private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)
21742175

2176+
final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)
2177+
21752178
override def checkInputDataTypes(): TypeCheckResult = {
21762179
if (children.isEmpty) {
21772180
TypeCheckResult.TypeCheckSuccess

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
2323
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2424
import org.apache.spark.sql.catalyst.trees.TernaryLike
25+
import org.apache.spark.sql.catalyst.trees.TreePattern.{CASE_WHEN, IF, TreePattern}
2526
import org.apache.spark.sql.types._
2627

2728
// scalastyle:off line.size.limit
@@ -48,6 +49,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
4849
override def third: Expression = falseValue
4950
override def nullable: Boolean = trueValue.nullable || falseValue.nullable
5051

52+
final override val nodePatterns : Seq[TreePattern] = Seq(IF)
53+
5154
override def checkInputDataTypes(): TypeCheckResult = {
5255
if (predicate.dataType != BooleanType) {
5356
TypeCheckResult.TypeCheckFailure(
@@ -139,6 +142,8 @@ case class CaseWhen(
139142

140143
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
141144

145+
final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)
146+
142147
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
143148
super.legacyWithNewChildren(newChildren)
144149

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2222
import org.apache.spark.sql.catalyst.expressions.codegen._
2323
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
24+
import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
2425
import org.apache.spark.sql.catalyst.util.TypeUtils
2526
import org.apache.spark.sql.types._
2627

@@ -345,6 +346,8 @@ case class NaNvl(left: Expression, right: Expression)
345346
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
346347
override def nullable: Boolean = false
347348

349+
final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
350+
348351
override def eval(input: InternalRow): Any = {
349352
child.eval(input) == null
350353
}
@@ -375,6 +378,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
375378
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
376379
override def nullable: Boolean = false
377380

381+
final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
382+
378383
override def eval(input: InternalRow): Any = {
379384
child.eval(input) != null
380385
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
3333
import org.apache.spark.sql.catalyst.expressions.codegen._
3434
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
3535
import org.apache.spark.sql.catalyst.trees.TernaryLike
36+
import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
3637
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
3738
import org.apache.spark.sql.errors.QueryExecutionErrors
3839
import org.apache.spark.sql.types._
@@ -1705,6 +1706,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
17051706
override def foldable: Boolean = false
17061707
override def nullable: Boolean = false
17071708

1709+
final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
1710+
17081711
override def flatArguments: Iterator[Any] = Iterator(child)
17091712

17101713
private val errMsg = "Null value appeared in non-nullable field:" +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
2626
import org.apache.spark.sql.catalyst.expressions.codegen._
2727
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2828
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
29-
import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, IN_SUBQUERY, INSET, TreePattern}
29+
import org.apache.spark.sql.catalyst.trees.TreePattern._
3030
import org.apache.spark.sql.catalyst.util.TypeUtils
3131
import org.apache.spark.sql.internal.SQLConf
3232
import org.apache.spark.sql.types._
@@ -309,6 +309,8 @@ case class Not(child: Expression)
309309

310310
override def inputTypes: Seq[DataType] = Seq(BooleanType)
311311

312+
final override val nodePatterns: Seq[TreePattern] = Seq(NOT)
313+
312314
// +---------+-----------+
313315
// | CHILD | NOT CHILD |
314316
// +---------+-----------+
@@ -435,7 +437,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
435437
override def nullable: Boolean = children.exists(_.nullable)
436438
override def foldable: Boolean = children.forall(_.foldable)
437439

438-
override val nodePatterns: Seq[TreePattern] = Seq(IN)
440+
final override val nodePatterns: Seq[TreePattern] = Seq(IN)
439441

440442
override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
441443

@@ -548,7 +550,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
548550

549551
override def nullable: Boolean = child.nullable || hasNull
550552

551-
override val nodePatterns: Seq[TreePattern] = Seq(INSET)
553+
final override val nodePatterns: Seq[TreePattern] = Seq(INSET)
552554

553555
protected override def nullSafeEval(value: Any): Any = {
554556
if (set.contains(value)) {
@@ -666,6 +668,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
666668

667669
override def sqlOperator: String = "AND"
668670

671+
final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
672+
669673
// +---------+---------+---------+---------+
670674
// | AND | TRUE | FALSE | UNKNOWN |
671675
// +---------+---------+---------+---------+
@@ -752,6 +756,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
752756

753757
override def sqlOperator: String = "OR"
754758

759+
final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
760+
755761
// +---------+---------+---------+---------+
756762
// | OR | TRUE | FALSE | UNKNOWN |
757763
// +---------+---------+---------+---------+
@@ -823,6 +829,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
823829
// finitely enumerable. The allowable types are checked below by checkInputDataTypes.
824830
override def inputType: AbstractDataType = AnyDataType
825831

832+
final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_COMPARISON)
833+
826834
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
827835
case TypeCheckResult.TypeCheckSuccess =>
828836
TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
3030
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
3131
import org.apache.spark.sql.catalyst.expressions.codegen._
3232
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
33+
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern}
3334
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
3435
import org.apache.spark.sql.errors.QueryExecutionErrors
3536
import org.apache.spark.sql.types._
@@ -129,6 +130,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
129130

130131
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
131132

133+
final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
134+
132135
override def toString: String = escapeChar match {
133136
case '\\' => s"$left LIKE $right"
134137
case c => s"$left LIKE $right ESCAPE '$c'"
@@ -198,6 +201,8 @@ sealed abstract class MultiLikeBase
198201

199202
override def nullable: Boolean = true
200203

204+
final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
205+
201206
protected lazy val hasNull: Boolean = patterns.contains(null)
202207

203208
protected lazy val cache = patterns.filterNot(_ == null)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
3030
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
3131
import org.apache.spark.sql.catalyst.expressions.codegen._
3232
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
33+
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
3334
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
3435
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
3536
import org.apache.spark.sql.internal.SQLConf
@@ -406,6 +407,8 @@ case class Upper(child: Expression)
406407
override def convert(v: UTF8String): UTF8String = v.toUpperCase
407408
// scalastyle:on caselocale
408409

410+
final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
411+
409412
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
410413
defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
411414
}
@@ -432,6 +435,8 @@ case class Lower(child: Expression)
432435
override def convert(v: UTF8String): UTF8String = v.toLowerCase
433436
// scalastyle:on caselocale
434437

438+
final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
439+
435440
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
436441
defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
437442
}

0 commit comments

Comments (0)