
Commit b804ca5

ueshin authored and cloud-fan committed
[SPARK-23908][SQL][FOLLOW-UP] Rename inputs to arguments, and add argument type check.
## What changes were proposed in this pull request?

This is a follow-up PR of #21954 to address comments.

- Rename the ambiguous name `inputs` to `arguments`.
- Add an argument type check and remove the hacky workaround.
- Address other small comments.

## How was this patch tested?

Existing tests and some additional tests.

Closes #22075 from ueshin/issues/SPARK-23908/fup1.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 2e3abdf · commit b804ca5
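As a sketch of the user-visible effect of the new check (illustrative only: the session setup and the exact message text are assumptions, not taken from this commit's tests), passing a non-array first argument to a higher-order function such as `transform` is now reported as an argument data type mismatch instead of an unresolved lambda function:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object HofTypeCheckDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("hof-demo").getOrCreate()
  try {
    // `transform` expects an array as its first argument; an INT literal now
    // fails the argument type check added in CheckAnalysis rather than
    // surfacing as an unresolved lambda function.
    spark.sql("SELECT transform(5, x -> x + 1)").collect()
  } catch {
    case e: AnalysisException =>
      // Expected message shape (illustrative, not copied from this commit):
      // cannot resolve 'transform(5, ...)' due to argument data type mismatch:
      // argument 1 requires array type, however, '5' is of int type.
      println(e.getMessage)
  } finally {
    spark.stop()
  }
}
```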

File tree: 6 files changed, +152 −98 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 14 additions & 0 deletions
@@ -90,6 +90,20 @@ trait CheckAnalysis extends PredicateHelper {
         u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}")
 
       case operator: LogicalPlan =>
+        // Check argument data types of higher-order functions downwards first.
+        // If the arguments of the higher-order functions are resolved but the type check fails,
+        // the argument functions will not get resolved, but we should report the argument type
+        // check failure instead of claiming the argument functions are unresolved.
+        operator transformExpressionsDown {
+          case hof: HigherOrderFunction
+              if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure =>
+            hof.checkArgumentDataTypes() match {
+              case TypeCheckResult.TypeCheckFailure(message) =>
+                hof.failAnalysis(
+                  s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message")
+            }
+        }
+
         operator transformExpressionsUp {
           case a: Attribute if !a.resolved =>
             val from = operator.inputSet.map(_.qualifiedName).mkString(", ")
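The guard `hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure` is what lets the analyzer surface the real type error before the (now unresolvable) lambda children would trigger a generic "unresolved" failure. A minimal, self-contained analogue of this check-then-fail flow, with a hypothetical `Hof` trait standing in for `HigherOrderFunction` (only the `TypeCheckResult` shape mirrors Catalyst):

```scala
// Self-contained sketch: run the argument type check only once the arguments
// themselves are resolved, and report its failure message eagerly.
sealed trait TypeCheckResult { def isFailure: Boolean }
case object TypeCheckSuccess extends TypeCheckResult { val isFailure = false }
case class TypeCheckFailure(message: String) extends TypeCheckResult { val isFailure = true }

trait Hof {
  def sql: String
  def argumentsResolved: Boolean
  def checkArgumentDataTypes(): TypeCheckResult
}

def failOnBadArguments(hof: Hof): Unit =
  if (hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure) {
    hof.checkArgumentDataTypes() match {
      case TypeCheckFailure(message) =>
        throw new IllegalArgumentException(
          s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message")
      case TypeCheckSuccess => // unreachable given the guard above
    }
  }
```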

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala

Lines changed: 6 additions & 6 deletions
@@ -95,23 +95,23 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
    */
   private def createLambda(
       e: Expression,
-      partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match {
+      argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
     case f: LambdaFunction if f.bound => f
 
     case LambdaFunction(function, names, _) =>
-      if (names.size != partialArguments.size) {
+      if (names.size != argInfo.size) {
         e.failAnalysis(
           s"The number of lambda function arguments '${names.size}' does not " +
             "match the number of arguments expected by the higher order function " +
-            s"'${partialArguments.size}'.")
+            s"'${argInfo.size}'.")
       }
 
       if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) {
         e.failAnalysis(
           "Lambda function arguments should not have names that are semantically the same.")
       }
 
-      val arguments = partialArguments.zip(names).map {
+      val arguments = argInfo.zip(names).map {
         case ((dataType, nullable), ne) =>
           NamedLambdaVariable(ne.name, dataType, nullable)
       }
@@ -122,7 +122,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
     // create a lambda function with default parameters because this is expected by the higher
     // order function. Note that we hide the lambda variables produced by this function in order
     // to prevent accidental naming collisions.
-    val arguments = partialArguments.zipWithIndex.map {
+    val arguments = argInfo.zipWithIndex.map {
       case ((dataType, nullable), i) =>
         NamedLambdaVariable(s"col$i", dataType, nullable)
     }
@@ -135,7 +135,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
   private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match {
     case _ if e.resolved => e
 
-    case h: HigherOrderFunction if h.inputResolved =>
+    case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess =>
       h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap))
 
     case l: LambdaFunction if !l.bound =>
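For context, `createLambda` zips the `argInfo` pairs supplied by the higher-order function with the user-written parameter names, after validating arity and name uniqueness. A simplified, self-contained sketch of that logic (`NamedVar` and plain-string type names are stand-ins for `NamedLambdaVariable` and Catalyst data types):

```scala
// Simplified sketch of createLambda's argument handling: argInfo gives the
// expected (type, nullability) pair per lambda parameter, which is zipped
// with the user-written names once arity and uniqueness checks pass.
case class NamedVar(name: String, dataType: String, nullable: Boolean)

def bindLambda(names: Seq[String], argInfo: Seq[(String, Boolean)]): Seq[NamedVar] = {
  require(names.size == argInfo.size,
    s"The number of lambda function arguments '${names.size}' does not match " +
      s"the number of arguments expected by the higher order function '${argInfo.size}'.")
  require(names.map(_.toLowerCase).distinct.size == names.size,
    "Lambda function arguments should not have names that are semantically the same.")
  argInfo.zip(names).map { case ((dataType, nullable), name) =>
    NamedVar(name, dataType, nullable)
  }
}

// For `transform(array(1, 2), x -> x + 1)` the function expects one int argument:
// bindLambda(Seq("x"), Seq(("int", false))) == Seq(NamedVar("x", "int", false))
```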

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala

Lines changed: 12 additions & 4 deletions
@@ -41,10 +41,19 @@ trait ExpectsInputTypes extends Expression {
   def inputTypes: Seq[AbstractDataType]
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    val mismatches = children.zip(inputTypes).zipWithIndex.collect {
-      case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
+    ExpectsInputTypes.checkInputDataTypes(children, inputTypes)
+  }
+}
+
+object ExpectsInputTypes {
+
+  def checkInputDataTypes(
+      inputs: Seq[Expression],
+      inputTypes: Seq[AbstractDataType]): TypeCheckResult = {
+    val mismatches = inputs.zip(inputTypes).zipWithIndex.collect {
+      case ((input, expected), idx) if !expected.acceptsType(input.dataType) =>
         s"argument ${idx + 1} requires ${expected.simpleString} type, " +
-          s"however, '${child.sql}' is of ${child.dataType.catalogString} type."
+          s"however, '${input.sql}' is of ${input.dataType.catalogString} type."
       }
 
     if (mismatches.isEmpty) {
@@ -55,7 +64,6 @@ trait ExpectsInputTypes extends Expression {
     }
   }
 
-
 /**
  * A mixin for the analyzer to perform implicit type casting using
  * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]].
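Hoisting the check into a companion object lets code outside the trait, such as a higher-order function's `checkArgumentDataTypes()`, reuse it for just the non-lambda arguments instead of all children. A standalone analogue of the extracted helper (`Input` and string type names are stand-ins for Catalyst's `Expression` and `AbstractDataType`):

```scala
// Standalone analogue of the helper extracted into object ExpectsInputTypes:
// zip inputs with the expected types, collect a message per mismatch, and
// succeed only when none were found.
sealed trait CheckResult
case object CheckSuccess extends CheckResult
case class CheckFailure(message: String) extends CheckResult

case class Input(sql: String, dataType: String)

def checkInputDataTypes(inputs: Seq[Input], inputTypes: Seq[String]): CheckResult = {
  val mismatches = inputs.zip(inputTypes).zipWithIndex.collect {
    case ((input, expected), idx) if input.dataType != expected =>
      s"argument ${idx + 1} requires $expected type, " +
        s"however, '${input.sql}' is of ${input.dataType} type."
  }
  if (mismatches.isEmpty) CheckSuccess else CheckFailure(mismatches.mkString(" "))
}

// checkInputDataTypes(Seq(Input("5", "int")), Seq("array")) ==
//   CheckFailure("argument 1 requires array type, however, '5' is of int type.")
```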
