
Commit 7d5f6e8

[SPARK-26293][SQL] Cast exception when having python udf in subquery
## What changes were proposed in this pull request?

This is a regression introduced by #22104 in Spark 2.4.0. When we have a Python UDF in a subquery, we hit an exception:

```
Caused by: java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.AttributeReference cannot be cast to org.apache.spark.sql.catalyst.expressions.PythonUDF
  at scala.collection.immutable.Stream.map(Stream.scala:414)
  at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:98)
  at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:815)
...
```

#22104 turned `ExtractPythonUDFs` from a physical rule into an optimizer rule. However, there is a difference between the two: a physical rule always runs once, while an optimizer rule may be applied twice to a query tree even if the rule sits in a batch that only runs once. For a subquery, the `OptimizeSubqueries` rule executes the entire optimizer on the plan inside the subquery; later the subquery is turned into a join, and the optimizer rules are applied to it again.

Unfortunately, the `ExtractPythonUDFs` rule is not idempotent. When it is applied twice to the plan inside a subquery, it produces a malformed plan: it extracts Python UDFs out of the Python exec plans produced by the first application.

This PR proposes two changes, to be doubly safe:
1. `ExtractPythonUDFs` should skip Python exec plans, to make the rule idempotent.
2. `ExtractPythonUDFs` should skip subqueries.

## How was this patch tested?

A new test.

Closes #23248 from cloud-fan/python.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
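For illustration, here is a minimal PySpark reproduction sketched from the regression test added below; the session bootstrap is assumed for this example only, and on an affected 2.4.0 build the final query is where the `ClassCastException` surfaces:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

# Assumed local session, for illustration only.
spark = SparkSession.builder.getOrCreate()

# A Python UDF used in a filter; the filtered plan backs a temp view.
f = udf(lambda x: x, "long")
spark.range(1).filter(f("id") >= 0).createTempView("v")

# The IN-subquery over the view is what triggered the ClassCastException on Spark 2.4.0,
# because ExtractPythonUDFs was applied a second time to the plan inside the subquery.
result = spark.sql("select i from values(0L) as data(i) where i in (select id from v)")
print(result.collect())  # expected once fixed: [Row(i=0)]
```

The new `test_udf_in_subquery` test added in this commit performs the same steps inside the PySpark test suite.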
1 parent cbe9230 commit 7d5f6e8


4 files changed, +46 -40 lines changed


python/pyspark/sql/tests/test_udf.py

Lines changed: 19 additions & 33 deletions
@@ -23,7 +23,7 @@
 
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.functions import UserDefinedFunction, udf
 from pyspark.sql.types import *
 from pyspark.sql.utils import AnalysisException
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
@@ -102,7 +102,6 @@ def test_udf_registration_return_type_not_none(self):
 
     def test_nondeterministic_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
-        from pyspark.sql.functions import udf
         import random
         udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
         self.assertEqual(udf_random_col.deterministic, False)
@@ -113,7 +112,6 @@ def test_nondeterministic_udf(self):
 
     def test_nondeterministic_udf2(self):
         import random
-        from pyspark.sql.functions import udf
         random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
         self.assertEqual(random_udf.deterministic, False)
         random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
@@ -132,7 +130,6 @@ def test_nondeterministic_udf2(self):
 
     def test_nondeterministic_udf3(self):
         # regression test for SPARK-23233
-        from pyspark.sql.functions import udf
         f = udf(lambda x: x)
         # Here we cache the JVM UDF instance.
         self.spark.range(1).select(f("id"))
@@ -144,7 +141,7 @@ def test_nondeterministic_udf3(self):
         self.assertFalse(deterministic)
 
     def test_nondeterministic_udf_in_aggregate(self):
-        from pyspark.sql.functions import udf, sum
+        from pyspark.sql.functions import sum
         import random
         udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
         df = self.spark.range(10)
@@ -181,7 +178,6 @@ def test_multiple_udfs(self):
         self.assertEqual(tuple(row), (6, 5))
 
     def test_udf_in_filter_on_top_of_outer_join(self):
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(a=1)])
         df = left.join(right, on='a', how='left_outer')
@@ -190,7 +186,6 @@ def test_udf_in_filter_on_top_of_outer_join(self):
 
     def test_udf_in_filter_on_top_of_join(self):
         # regression test for SPARK-18589
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -199,7 +194,6 @@ def test_udf_in_filter_on_top_of_join(self):
 
     def test_udf_in_join_condition(self):
         # regression test for SPARK-25314
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -211,7 +205,7 @@ def test_udf_in_join_condition(self):
 
     def test_udf_in_left_outer_join_condition(self):
         # regression test for SPARK-26147
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import col
         left = self.spark.createDataFrame([Row(a=1)])
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a: str(a), StringType())
@@ -223,7 +217,6 @@ def test_udf_in_left_outer_join_condition(self):
 
     def test_udf_in_left_semi_join_condition(self):
         # regression test for SPARK-25314
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -236,7 +229,6 @@ def test_udf_in_left_semi_join_condition(self):
     def test_udf_and_common_filter_in_join_condition(self):
         # regression test for SPARK-25314
         # test the complex scenario with both udf and common filter
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -247,7 +239,6 @@ def test_udf_and_common_filter_in_join_condition(self):
     def test_udf_and_common_filter_in_left_semi_join_condition(self):
         # regression test for SPARK-25314
         # test the complex scenario with both udf and common filter
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -258,7 +249,6 @@ def test_udf_and_common_filter_in_left_semi_join_condition(self):
     def test_udf_not_supported_in_join_condition(self):
         # regression test for SPARK-25314
         # test python udf is not supported in join type besides left_semi and inner join.
-        from pyspark.sql.functions import udf
         left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
         right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
         f = udf(lambda a, b: a == b, BooleanType())
@@ -301,7 +291,7 @@ def test_broadcast_in_udf(self):
 
     def test_udf_with_filter_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import col
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a < 2, BooleanType())
@@ -310,7 +300,7 @@ def test_udf_with_filter_function(self):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col, sum
+        from pyspark.sql.functions import col, sum
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
@@ -326,7 +316,7 @@ def test_udf_with_aggregate_function(self):
         self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
 
     def test_udf_in_generate(self):
-        from pyspark.sql.functions import udf, explode
+        from pyspark.sql.functions import explode
         df = self.spark.range(5)
         f = udf(lambda x: list(range(x)), ArrayType(LongType()))
         row = df.select(explode(f(*df))).groupBy().sum().first()
@@ -353,7 +343,6 @@ def test_udf_in_generate(self):
         self.assertEqual(res[3][1], 1)
 
     def test_udf_with_order_by_and_limit(self):
-        from pyspark.sql.functions import udf
         my_copy = udf(lambda x: x, IntegerType())
         df = self.spark.range(10).orderBy("id")
         res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
@@ -394,14 +383,14 @@ def test_non_existed_udaf(self):
             lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
 
     def test_udf_with_input_file_name(self):
-        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.functions import input_file_name
         sourceFile = udf(lambda path: path, StringType())
         filePath = "python/test_support/sql/people1.json"
         row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
         self.assertTrue(row[0].find("people1.json") != -1)
 
     def test_udf_with_input_file_name_for_hadooprdd(self):
-        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.functions import input_file_name
 
         def filename(path):
             return path
@@ -427,9 +416,6 @@ def test_udf_defers_judf_initialization(self):
         # This is separate of UDFInitializationTests
         # to avoid context initialization
         # when udf is called
-
-        from pyspark.sql.functions import UserDefinedFunction
-
         f = UserDefinedFunction(lambda x: x, StringType())
 
         self.assertIsNone(
@@ -445,8 +431,6 @@ def test_udf_defers_judf_initialization(self):
         )
 
     def test_udf_with_string_return_type(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
         add_one = UserDefinedFunction(lambda x: x + 1, "integer")
         make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
         make_array = UserDefinedFunction(
@@ -460,13 +444,11 @@ def test_udf_with_string_return_type(self):
         self.assertTupleEqual(expected, actual)
 
     def test_udf_shouldnt_accept_noncallable_object(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
         non_callable = None
         self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
 
     def test_udf_with_decorator(self):
-        from pyspark.sql.functions import lit, udf
+        from pyspark.sql.functions import lit
         from pyspark.sql.types import IntegerType, DoubleType
 
         @udf(IntegerType())
@@ -523,7 +505,6 @@ def as_double(x):
         )
 
     def test_udf_wrapper(self):
-        from pyspark.sql.functions import udf
        from pyspark.sql.types import IntegerType
 
         def f(x):
@@ -569,7 +550,7 @@ def test_nonparam_udf_with_aggregate(self):
     # SPARK-24721
     @unittest.skipIf(not test_compiled, test_not_compiled_message)
     def test_datasource_with_udf(self):
-        from pyspark.sql.functions import udf, lit, col
+        from pyspark.sql.functions import lit, col
 
         path = tempfile.mkdtemp()
         shutil.rmtree(path)
@@ -609,8 +590,6 @@ def test_datasource_with_udf(self):
 
     # SPARK-25591
     def test_same_accumulator_in_udfs(self):
-        from pyspark.sql.functions import udf
-
         data_schema = StructType([StructField("a", IntegerType(), True),
                                   StructField("b", IntegerType(), True)])
         data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
@@ -632,6 +611,15 @@ def second_udf(x):
         data.collect()
         self.assertEqual(test_accum.value, 101)
 
+    # SPARK-26293
+    def test_udf_in_subquery(self):
+        f = udf(lambda x: x, "long")
+        with self.tempView("v"):
+            self.spark.range(1).filter(f("id") >= 0).createTempView("v")
+            sql = self.spark.sql
+            result = sql("select i from values(0L) as data(i) where i in (select id from v)")
+            self.assertEqual(result.collect(), [Row(i=0)])
+
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
@@ -642,8 +630,6 @@ def tearDown(self):
             SparkContext._active_spark_context.stop()
 
     def test_udf_init_shouldnt_initialize_context(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
         UserDefinedFunction(lambda x: x, StringType())
 
         self.assertIsNone(

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 6 additions & 2 deletions
@@ -60,8 +60,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
 /**
  * A logical plan that evaluates a [[PythonUDF]].
  */
-case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
-  extends UnaryNode
+case class ArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}
 
 /**
  * A physical plan that evaluates a [[PythonUDF]].

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala

Lines changed: 6 additions & 2 deletions
@@ -32,8 +32,12 @@ import org.apache.spark.sql.types.{StructField, StructType}
 /**
  * A logical plan that evaluates a [[PythonUDF]]
  */
-case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
-  extends UnaryNode
+case class BatchEvalPython(
+    udfs: Seq[PythonUDF],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}
 
 /**
  * A physical plan that evaluates a [[PythonUDF]]

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 15 additions & 3 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 
 
@@ -131,8 +131,20 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
     expressions.flatMap(collectEvaluableUDFs)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case plan: LogicalPlan => extract(plan)
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
+    // eventually. Here we skip subquery, as Python UDF only needs to be extracted once.
+    case _: Subquery => plan
+
+    case _ => plan transformUp {
+      // A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and
+      // `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't
+      // extract Python UDFs from them.
+      case p: BatchEvalPython => p
+      case p: ArrowEvalPython => p
+
+      case plan: LogicalPlan => extract(plan)
+    }
   }
 
   /**

0 commit comments
