Skip to content

Commit 07cbba6

Browse files
HyukjinKwonyaooqinn
authored andcommitted
[SPARK-48706][PYTHON] Python UDF in higher order functions should not throw internal error
### What changes were proposed in this pull request? This PR fixes the error messages and classes when Python UDFs are used in higher order functions. ### Why are the changes needed? To show the proper user-facing exceptions with error classes. ### Does this PR introduce _any_ user-facing change? Yes, previously it threw internal error such as: ```python from pyspark.sql.functions import transform, udf, col, array spark.range(1).select(transform(array("id"), lambda x: udf(lambda y: y)(x))).collect() ``` Before: ``` py4j.protocol.Py4JJavaError: An error occurred while calling o74.collectToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 15 in stage 0.0 failed 1 times, most recent failure: Lost task 15.0 in stage 0.0 (TID 15) (ip-192-168-123-103.ap-northeast-2.compute.internal executor driver): org.apache.spark.SparkException: [INTERNAL_ERROR] Cannot evaluate expression: <lambda>(lambda x_0#3L)#2 SQLSTATE: XX000 at org.apache.spark.SparkException$.internalError(SparkException.scala:92) at org.apache.spark.SparkException$.internalError(SparkException.scala:96) ``` After: ``` pyspark.errors.exceptions.captured.AnalysisException: [INVALID_LAMBDA_FUNCTION_CALL.UNEVALUABLE] Invalid lambda function call. Python UDFs should be used in a lambda function at a higher order function. However, "<lambda>(lambda x_0#3L)" was a Python UDF. SQLSTATE: 42K0D; Project [transform(array(id#0L), lambdafunction(<lambda>(lambda x_0#3L)#2, lambda x_0#3L, false)) AS transform(array(id), lambdafunction(<lambda>(lambda x_0#3L), namedlambdavariable()))#4] +- Range (0, 1, step=1, splits=Some(16)) ``` ### How was this patch tested? Unittest was added ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47079 from HyukjinKwon/SPARK-48706. Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Kent Yao <yao@apache.org>
1 parent 169346c commit 07cbba6

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4482,6 +4482,11 @@
44824482
"INSERT INTO <tableName> with IF NOT EXISTS in the PARTITION spec."
44834483
]
44844484
},
4485+
"LAMBDA_FUNCTION_WITH_PYTHON_UDF" : {
4486+
"message" : [
4487+
"Lambda function with Python UDF <funcName> in a higher order function."
4488+
]
4489+
},
44854490
"LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC" : {
44864491
"message" : [
44874492
"Referencing a lateral column alias <lca> in the aggregate function <aggFunc>."

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
254254
hof.invalidFormat(checkRes)
255255
}
256256

257+
case hof: HigherOrderFunction
258+
if hof.resolved && hof.functions
259+
.exists(_.exists(_.isInstanceOf[PythonUDF])) =>
260+
val u = hof.functions.flatMap(_.find(_.isInstanceOf[PythonUDF])).head
261+
hof.failAnalysis(
262+
errorClass = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_PYTHON_UDF",
263+
messageParameters = Map("funcName" -> toSQLExpr(u)))
264+
257265
// If an attribute can't be resolved as a map key of string type, either the key should be
258266
// surrounded with single quotes, or there is a typo in the attribute name.
259267
case GetMapValue(map, key: Attribute) if isMapWithStringKey(map) && !key.resolved =>

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
package org.apache.spark.sql.execution.python
1919

20-
import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest}
21-
import org.apache.spark.sql.functions.count
20+
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest}
21+
import org.apache.spark.sql.functions.{array, count, transform}
2222
import org.apache.spark.sql.test.SharedSparkSession
2323
import org.apache.spark.sql.types.LongType
2424

@@ -112,4 +112,16 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
112112
val pandasTestUDF = TestGroupedAggPandasUDF(name = udfName)
113113
assert(df.agg(pandasTestUDF(df("id"))).schema.fieldNames.exists(_.startsWith(udfName)))
114114
}
115+
116+
test("SPARK-48706: Negative test case for Python UDF in higher order functions") {
117+
assume(shouldTestPythonUDFs)
118+
checkError(
119+
exception = intercept[AnalysisException] {
120+
spark.range(1).select(transform(array("id"), x => pythonTestUDF(x))).collect()
121+
},
122+
errorClass = "UNSUPPORTED_FEATURE.LAMBDA_FUNCTION_WITH_PYTHON_UDF",
123+
parameters = Map("funcName" -> "\"pyUDF(namedlambdavariable())\""),
124+
context = ExpectedContext(
125+
"transform", s".*${this.getClass.getSimpleName}.*"))
126+
}
115127
}

0 commit comments

Comments
 (0)