Skip to content

Commit aa352a2

Browse files
panbingkun and SandishKumarHN
authored and committed
[SPARK-40751][SQL] Migrate type check failures of high order functions onto error classes
### What changes were proposed in this pull request? This PR aims to replace TypeCheckFailure by DataTypeMismatch in type checks in the high-order functions expressions, includes: - 1. ArraySort (2): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L403-L407 - 2. ArrayAggregate (1): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L807 - 3. MapZipWith (1): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L1028 ### Why are the changes needed? Migration onto error classes unifies Spark SQL error messages. ### Does this PR introduce _any_ user-facing change? Yes. The PR changes user-facing error messages. ### How was this patch tested? - Update existing UTs - Pass GA. Closes apache#38359 from panbingkun/SPARK-40751. Authored-by: panbingkun <pbk1982@gmail.com> Signed-off-by: Max Gekk <max.gekk@gmail.com>
1 parent 8c0b8bd commit aa352a2

File tree

6 files changed

+171
-30
lines changed

6 files changed

+171
-30
lines changed

core/src/main/resources/error/error-classes.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@
200200
"The <functionName> accepts only arrays of pair structs, but <childExpr> is of <childType>."
201201
]
202202
},
203+
"MAP_ZIP_WITH_DIFF_TYPES" : {
204+
"message" : [
205+
"Input to the <functionName> should have been two maps with compatible key types, but it's [<leftType>, <rightType>]."
206+
]
207+
},
203208
"NON_FOLDABLE_INPUT" : {
204209
"message" : [
205210
"the input <inputName> should be a foldable <inputType> expression; however, got <inputExpr>."
@@ -275,6 +280,11 @@
275280
"The <exprName> must not be null"
276281
]
277282
},
283+
"UNEXPECTED_RETURN_TYPE" : {
284+
"message" : [
285+
"The <functionName> requires return <expectedType> type, but the actual is <actualType> type."
286+
]
287+
},
278288
"UNEXPECTED_STATIC_METHOD" : {
279289
"message" : [
280290
"cannot find a static method <methodName> that matches the argument types in <className>"

core/src/test/scala/org/apache/spark/SparkFunSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,12 @@ abstract class SparkFunSuite
370370
checkError(exception, errorClass, sqlState, parameters,
371371
false, Array(context))
372372

373+
protected def checkErrorMatchPVals(
374+
exception: SparkThrowable,
375+
errorClass: String,
376+
parameters: Map[String, String]): Unit =
377+
checkError(exception, errorClass, None, parameters, matchPVals = true)
378+
373379
protected def checkErrorMatchPVals(
374380
exception: SparkThrowable,
375381
errorClass: String,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import scala.collection.mutable
2424

2525
import org.apache.spark.sql.catalyst.InternalRow
2626
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
27+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
28+
import org.apache.spark.sql.catalyst.expressions.Cast._
2729
import org.apache.spark.sql.catalyst.expressions.codegen._
2830
import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
2931
import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -400,11 +402,25 @@ case class ArraySort(
400402
if (function.dataType == IntegerType) {
401403
TypeCheckResult.TypeCheckSuccess
402404
} else {
403-
TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " +
404-
"IntegerType")
405+
DataTypeMismatch(
406+
errorSubClass = "UNEXPECTED_RETURN_TYPE",
407+
messageParameters = Map(
408+
"functionName" -> toSQLId(function.prettyName),
409+
"expectedType" -> toSQLType(IntegerType),
410+
"actualType" -> toSQLType(function.dataType)
411+
)
412+
)
405413
}
406414
case _ =>
407-
TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
415+
DataTypeMismatch(
416+
errorSubClass = "UNEXPECTED_INPUT_TYPE",
417+
messageParameters = Map(
418+
"paramIndex" -> "1",
419+
"requiredType" -> toSQLType(ArrayType),
420+
"inputSql" -> toSQLExpr(argument),
421+
"inputType" -> toSQLType(argument.dataType)
422+
)
423+
)
408424
}
409425
case failure => failure
410426
}
@@ -804,9 +820,13 @@ case class ArrayAggregate(
804820
case TypeCheckResult.TypeCheckSuccess =>
805821
if (!DataType.equalsStructurally(
806822
zero.dataType, merge.dataType, ignoreNullability = true)) {
807-
TypeCheckResult.TypeCheckFailure(
808-
s"argument 3 requires ${zero.dataType.simpleString} type, " +
809-
s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
823+
DataTypeMismatch(
824+
errorSubClass = "UNEXPECTED_INPUT_TYPE",
825+
messageParameters = Map(
826+
"paramIndex" -> "3",
827+
"requiredType" -> toSQLType(zero.dataType),
828+
"inputSql" -> toSQLExpr(merge),
829+
"inputType" -> toSQLType(merge.dataType)))
810830
} else {
811831
TypeCheckResult.TypeCheckSuccess
812832
}
@@ -1025,9 +1045,14 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
10251045
if (leftKeyType.sameType(rightKeyType)) {
10261046
TypeUtils.checkForOrderingExpr(leftKeyType, prettyName)
10271047
} else {
1028-
TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " +
1029-
s"been two ${MapType.simpleString}s with compatible key types, but the key types are " +
1030-
s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].")
1048+
DataTypeMismatch(
1049+
errorSubClass = "MAP_ZIP_WITH_DIFF_TYPES",
1050+
messageParameters = Map(
1051+
"functionName" -> toSQLId(prettyName),
1052+
"leftType" -> toSQLType(leftKeyType),
1053+
"rightType" -> toSQLType(rightKeyType)
1054+
)
1055+
)
10311056
}
10321057
case failure => failure
10331058
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.{SparkException, SparkFunSuite}
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
22+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
23+
import org.apache.spark.sql.catalyst.expressions.Cast._
2224
import org.apache.spark.sql.internal.SQLConf
2325
import org.apache.spark.sql.types._
2426

@@ -859,4 +861,20 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
859861
Seq(1, 1, 2, 3))
860862
}
861863
}
864+
865+
test("Return type of the given function has to be IntegerType") {
866+
val comparator = {
867+
val comp = ArraySort.comparator _
868+
(left: Expression, right: Expression) => Literal.create("hello", StringType)
869+
}
870+
871+
val result = arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator).checkInputDataTypes()
872+
assert(result == DataTypeMismatch(
873+
errorSubClass = "UNEXPECTED_RETURN_TYPE",
874+
messageParameters = Map(
875+
"functionName" -> toSQLId("lambdafunction"),
876+
"expectedType" -> toSQLType(IntegerType),
877+
"actualType" -> toSQLType(StringType)
878+
)))
879+
}
862880
}

sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,22 @@ FROM various_maps
8282
struct<>
8383
-- !query output
8484
org.apache.spark.sql.AnalysisException
85-
cannot resolve 'map_zip_with(various_maps.decimal_map1, various_maps.decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
86-
85+
{
86+
"errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
87+
"messageParameters" : {
88+
"functionName" : "`map_zip_with`",
89+
"leftType" : "\"DECIMAL(36,0)\"",
90+
"rightType" : "\"DECIMAL(36,35)\"",
91+
"sqlExpr" : "\"map_zip_with(decimal_map1, decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
92+
},
93+
"queryContext" : [ {
94+
"objectType" : "",
95+
"objectName" : "",
96+
"startIndex" : 8,
97+
"stopIndex" : 81,
98+
"fragment" : "map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2))"
99+
} ]
100+
}
87101

88102
-- !query
89103
SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
@@ -110,7 +124,22 @@ FROM various_maps
110124
struct<>
111125
-- !query output
112126
org.apache.spark.sql.AnalysisException
113-
cannot resolve 'map_zip_with(various_maps.decimal_map2, various_maps.int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
127+
{
128+
"errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
129+
"messageParameters" : {
130+
"functionName" : "`map_zip_with`",
131+
"leftType" : "\"DECIMAL(36,35)\"",
132+
"rightType" : "\"INT\"",
133+
"sqlExpr" : "\"map_zip_with(decimal_map2, int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
134+
},
135+
"queryContext" : [ {
136+
"objectType" : "",
137+
"objectName" : "",
138+
"startIndex" : 8,
139+
"stopIndex" : 76,
140+
"fragment" : "map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2))"
141+
} ]
142+
}
114143

115144

116145
-- !query

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
533533
)
534534
}
535535

536+
test("The given function only supports array input") {
537+
val df = Seq(1, 2, 3).toDF("a")
538+
checkErrorMatchPVals(
539+
exception = intercept[AnalysisException] {
540+
df.select(array_sort(col("a"), (x, y) => x - y))
541+
},
542+
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
543+
parameters = Map(
544+
"sqlExpr" -> """"array_sort\(a, lambdafunction\(\(x_\d+ - y_\d+\), x_\d+, y_\d+\)\)"""",
545+
"paramIndex" -> "1",
546+
"requiredType" -> "\"ARRAY\"",
547+
"inputSql" -> "\"a\"",
548+
"inputType" -> "\"INT\""
549+
))
550+
}
551+
536552
test("sort_array/array_sort functions") {
537553
val df = Seq(
538554
(Array[Int](2, 1, 3), Array("b", "c", "a")),
@@ -3492,15 +3508,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
34923508
"requiredType" -> "\"ARRAY\""))
34933509
// scalastyle:on line.size.limit
34943510

3495-
val ex4 = intercept[AnalysisException] {
3496-
df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
3497-
}
3498-
assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
3511+
// scalastyle:off line.size.limit
3512+
checkError(
3513+
exception = intercept[AnalysisException] {
3514+
df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
3515+
},
3516+
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
3517+
parameters = Map(
3518+
"sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""",
3519+
"paramIndex" -> "3",
3520+
"inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"",
3521+
"inputType" -> "\"STRING\"",
3522+
"requiredType" -> "\"INT\""
3523+
))
3524+
// scalastyle:on line.size.limit
34993525

3500-
val ex4a = intercept[AnalysisException] {
3501-
df.select(aggregate(col("s"), lit(0), (acc, x) => x))
3502-
}
3503-
assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type"))
3526+
// scalastyle:off line.size.limit
3527+
checkError(
3528+
exception = intercept[AnalysisException] {
3529+
df.select(aggregate(col("s"), lit(0), (acc, x) => x))
3530+
},
3531+
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
3532+
parameters = Map(
3533+
"sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""",
3534+
"paramIndex" -> "3",
3535+
"inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"",
3536+
"inputType" -> "\"STRING\"",
3537+
"requiredType" -> "\"INT\""
3538+
))
3539+
// scalastyle:on line.size.limit
35043540

35053541
checkError(
35063542
exception =
@@ -3570,17 +3606,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
35703606
}
35713607
assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))
35723608

3573-
val ex2 = intercept[AnalysisException] {
3574-
df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))")
3575-
}
3576-
assert(ex2.getMessage.contains("The input to function map_zip_with should have " +
3577-
"been two maps with compatible key types"))
3609+
checkError(
3610+
exception = intercept[AnalysisException] {
3611+
df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))")
3612+
},
3613+
errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
3614+
parameters = Map(
3615+
"sqlExpr" -> "\"map_zip_with(mis, mmi, lambdafunction(concat(x, y, z), x, y, z))\"",
3616+
"functionName" -> "`map_zip_with`",
3617+
"leftType" -> "\"INT\"",
3618+
"rightType" -> "\"MAP<INT, INT>\""),
3619+
context = ExpectedContext(
3620+
fragment = "map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))",
3621+
start = 0,
3622+
stop = 51))
35783623

3579-
val ex2a = intercept[AnalysisException] {
3580-
df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z)))
3581-
}
3582-
assert(ex2a.getMessage.contains("The input to function map_zip_with should have " +
3583-
"been two maps with compatible key types"))
3624+
// scalastyle:off line.size.limit
3625+
checkError(
3626+
exception = intercept[AnalysisException] {
3627+
df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z)))
3628+
},
3629+
errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
3630+
matchPVals = true,
3631+
parameters = Map(
3632+
"sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""",
3633+
"functionName" -> "`map_zip_with`",
3634+
"leftType" -> "\"INT\"",
3635+
"rightType" -> "\"MAP<INT, INT>\""))
3636+
// scalastyle:on line.size.limit
35843637

35853638
checkError(
35863639
exception = intercept[AnalysisException] {

0 commit comments

Comments (0)