Skip to content

Commit 28f4da4

Browse files
committed
Python UDF in higher order functions should not throw internal error
1 parent 2ac2710 commit 28f4da4

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,6 +2425,11 @@
24252425
"message" : [
24262426
"A higher order function expects <expectedNumArgs> arguments, but got <actualNumArgs>."
24272427
]
2428+
},
2429+
"UNEVALUABLE" : {
2430+
"message" : [
2431+
"Evaluable expressions should be used for a lambda function in a higher order function. However, <funcName> was unevaluable."
2432+
]
24282433
}
24292434
},
24302435
"sqlState" : "42K0D"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
254254
hof.invalidFormat(checkRes)
255255
}
256256

257+
case e: HigherOrderFunction
258+
if e.resolved && e.functions.exists(_.exists(_.isInstanceOf[Unevaluable])) =>
259+
val u = e.functions.flatMap(_.find(_.isInstanceOf[Unevaluable])).head
260+
e.failAnalysis(
261+
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.UNEVALUABLE",
262+
messageParameters = Map("funcName" -> toSQLExpr(u)))
263+
257264
// If an attribute can't be resolved as a map key of string type, either the key should be
258265
// surrounded with single quotes, or there is a typo in the attribute name.
259266
case GetMapValue(map, key: Attribute) if isMapWithStringKey(map) && !key.resolved =>

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
package org.apache.spark.sql.execution.python
1919

20-
import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest}
21-
import org.apache.spark.sql.functions.count
20+
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest}
21+
import org.apache.spark.sql.functions.{array, count, transform}
2222
import org.apache.spark.sql.test.SharedSparkSession
2323
import org.apache.spark.sql.types.LongType
2424

@@ -112,4 +112,16 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
112112
val pandasTestUDF = TestGroupedAggPandasUDF(name = udfName)
113113
assert(df.agg(pandasTestUDF(df("id"))).schema.fieldNames.exists(_.startsWith(udfName)))
114114
}
115+
116+
test("SPARK-48706: Negative test case for Python UDF in higher order functions") {
117+
assume(shouldTestPythonUDFs)
118+
checkError(
119+
exception = intercept[AnalysisException] {
120+
spark.range(1).select(transform(array("id"), x => pythonTestUDF(x))).collect()
121+
},
122+
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.UNEVALUABLE",
123+
parameters = Map("funcName" -> "\"pyUDF(namedlambdavariable())\""),
124+
context = ExpectedContext(
125+
"transform", s".*${this.getClass.getSimpleName}.*"))
126+
}
115127
}

0 commit comments

Comments
 (0)