
Commit 32e5f39

HyukjinKwon and weiatwork authored and committed
[SPARK-48666][SQL] Do not push down filter if it contains PythonUDFs
### What changes were proposed in this pull request?

This PR proposes to prevent pushing down filters that contain Python UDFs. It uses the same approach as apache#47033 (whose author is therefore added as a co-author) but simplifies the change.

Extracting filters to push down happens first:

https://github.com/apache/spark/blob/cbe6846c477bc8b6d94385ddd0097c4e97b05d41/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala#L46
https://github.com/apache/spark/blob/cbe6846c477bc8b6d94385ddd0097c4e97b05d41/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L211
https://github.com/apache/spark/blob/cbe6846c477bc8b6d94385ddd0097c4e97b05d41/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala#L51

before Python UDFs are extracted:

https://github.com/apache/spark/blob/cbe6846c477bc8b6d94385ddd0097c4e97b05d41/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala#L80

Here is the full stacktrace:

```
[INTERNAL_ERROR] Cannot evaluate expression: pyUDF(cast(input[0, bigint, true] as string)) SQLSTATE: XX000
org.apache.spark.SparkException: [INTERNAL_ERROR] Cannot evaluate expression: pyUDF(cast(input[0, bigint, true] as string)) SQLSTATE: XX000
    at org.apache.spark.SparkException$.internalError(SparkException.scala:92)
    at org.apache.spark.SparkException$.internalError(SparkException.scala:96)
    at org.apache.spark.sql.errors.QueryExecutionErrors$.cannotEvaluateExpressionError(QueryExecutionErrors.scala:65)
    at org.apache.spark.sql.catalyst.expressions.FoldableUnevaluable.eval(Expression.scala:387)
    at org.apache.spark.sql.catalyst.expressions.FoldableUnevaluable.eval$(Expression.scala:386)
    at org.apache.spark.sql.catalyst.expressions.PythonUDF.eval(PythonUDF.scala:72)
    at org.apache.spark.sql.catalyst.expressions.UnaryExpression.eval(Expression.scala:563)
    at org.apache.spark.sql.catalyst.expressions.IsNotNull.eval(nullExpressions.scala:403)
    at org.apache.spark.sql.catalyst.expressions.InterpretedPredicate.eval(predicates.scala:53)
    at org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.$anonfun$prunePartitionsByFilter$1(ExternalCatalogUtils.scala:189)
    at org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.$anonfun$prunePartitionsByFilter$1$adapted(ExternalCatalogUtils.scala:188)
    at scala.collection.immutable.List.filter(List.scala:516)
    at scala.collection.immutable.List.filter(List.scala:79)
    at org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.prunePartitionsByFilter(ExternalCatalogUtils.scala:188)
    at org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.listPartitionsByFilter(InMemoryCatalog.scala:604)
    at org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener.listPartitionsByFilter(ExternalCatalogWithListener.scala:262)
    at org.apache.spark.sql.catalyst.catalog.SessionCatalog.listPartitionsByFilter(SessionCatalog.scala:1358)
    at org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.listPartitionsByFilter(ExternalCatalogUtils.scala:168)
    at org.apache.spark.sql.execution.datasources.CatalogFileIndex.filterPartitions(CatalogFileIndex.scala:74)
    at org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions$$anonfun$apply$1.applyOrElse(PruneFileSourcePartitions.scala:72)
    at org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions$$anonfun$apply$1.applyOrElse(PruneFileSourcePartitions.scala:50)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:470)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:84)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:470)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:37)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:330)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:326)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:475)
    at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1251)
    at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1250)
    at org.apache.spark.sql.catalyst.plans.logical.Join.mapChildren(basicLogicalOperators.scala:552)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:475)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:37)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:330)
    at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:326)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:446)
    at org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions$.apply(PruneFileSourcePartitions.scala:50)
    at org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions$.apply(PruneFileSourcePartitions.scala:35)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$2(RuleExecutor.scala:226)
    at scala.collection.LinearSeqOps.foldLeft(LinearSeq.scala:183)
    at scala.collection.LinearSeqOps.foldLeft$(LinearSeq.scala:179)
    at scala.collection.immutable.List.foldLeft(List.scala:79)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:223)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:215)
    at scala.collection.immutable.List.foreach(List.scala:334)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:215)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:186)
    at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:89)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:186)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$optimizedPlan$1(QueryExecution.scala:167)
    at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:138)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$2(QueryExecution.scala:234)
    at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:608)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:234)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:923)
    at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:233)
    at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:163)
    at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:159)
    at org.apache.spark.sql.execution.python.PythonUDFSuite.$anonfun$new$19(PythonUDFSuite.scala:136)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.scala:18)
    at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
    at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
    at org.apache.spark.sql.test.SQLTestUtilsBase.withTable(SQLTestUtils.scala:307)
    at org.apache.spark.sql.test.SQLTestUtilsBase.withTable$(SQLTestUtils.scala:305)
    at org.apache.spark.sql.execution.python.PythonUDFSuite.withTable(PythonUDFSuite.scala:25)
    at org.apache.spark.sql.execution.python.PythonUDFSuite.$anonfun$new$18(PythonUDFSuite.scala:130)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.scala:18)
    at org.scalatest.enablers.Timed$$anon$1.timeoutAfter(Timed.scala:127)
    at org.scalatest.concurrent.TimeLimits$.failAfterImpl(TimeLimits.scala:282)
    at org.scalatest.concurrent.TimeLimits.failAfter(TimeLimits.scala:231)
    at org.scalatest.concurrent.TimeLimits.failAfter$(TimeLimits.scala:230)
    at org.apache.spark.SparkFunSuite.failAfter(SparkFunSuite.scala:69)
    at org.apache.spark.SparkFunSuite.$anonfun$test$2(SparkFunSuite.scala:155)
    at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
    at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
    at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
    at org.scalatest.Transformer.apply(Transformer.scala:22)
    at org.scalatest.Transformer.apply(Transformer.scala:20)
    at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:226)
    at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:227)
    at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:224)
    at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:236)
    at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:236)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:218)
    at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:69)
    at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
    at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
    at org.apache.spark.SparkFunSuite.runTest(SparkFunSuite.scala:69)
    at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:269)
    at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
    at scala.collection.immutable.List.foreach(List.scala:334)
    at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
    at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
    at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:269)
    at org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:268)
    at org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1564)
    at org.scalatest.Suite.run(Suite.scala:1114)
    at org.scalatest.Suite.run$(Suite.scala:1096)
    at org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1564)
    at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:273)
    at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
    at org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:273)
    at org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:272)
    at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:69)
    at org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
    at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
    at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
    at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:69)
    at org.scalatest.tools.SuiteRunner.run(SuiteRunner.scala:47)
    at org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13(Runner.scala:1321)
    at org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13$adapted(Runner.scala:1315)
    at scala.collection.immutable.List.foreach(List.scala:334)
    at org.scalatest.tools.Runner$.doRunRunRunDaDoRunRun(Runner.scala:1315)
    at org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24(Runner.scala:992)
    at org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24$adapted(Runner.scala:970)
    at org.scalatest.tools.Runner$.withClassLoaderAndDispatchReporter(Runner.scala:1481)
    at org.scalatest.tools.Runner$.runOptionallyWithPassFailReporter(Runner.scala:970)
    at org.scalatest.tools.Runner$.run(Runner.scala:798)
    at org.scalatest.tools.Runner.run(Runner.scala)
    at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.runScalaTest2or3(ScalaTestRunner.java:43)
    at org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.main(ScalaTestRunner.java:26)
```

### Why are the changes needed?

So that end users can use Python UDFs against partitioned columns.

### Does this PR introduce _any_ user-facing change?

Yes, this fixes a bug: Python UDFs can now be used against partitioned columns.

### How was this patch tested?

Unit test added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47033

Closes apache#47313 from HyukjinKwon/SPARK-48666.

Lead-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Co-authored-by: Wei Zheng <weiz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
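For context, the failure can be reproduced along the lines of the new test added below. The following is a rough sketch only, not part of the commit; it assumes a Spark SQL test environment where `IntegratedUDFTestUtils` is on the classpath and a working Python interpreter with pyspark is available.

```scala
import org.apache.spark.sql.IntegratedUDFTestUtils
import org.apache.spark.sql.functions.col

// Register a Python test UDF named "pyUDF" (requires a usable Python environment).
val pyUDF = IntegratedUDFTestUtils.TestPythonUDF(name = "pyUDF")
IntegratedUDFTestUtils.registerTestUDF(pyUDF, spark)

// A table partitioned by `p`; the Python UDF is then applied to the partition column.
spark.range(1).selectExpr("id AS t", "(id + 1) AS p")
  .write.partitionBy("p").saveAsTable("t")
val withUDF = spark.table("t").withColumn("new_column", pyUDF(col("p")))

// Self-joining on the UDF output makes the optimizer derive an isnotnull(pyUDF(p))
// filter; before this fix it was pushed into partition pruning and evaluated
// interpretively, producing the INTERNAL_ERROR shown in the stacktrace above.
withUDF.as("t1")
  .join(withUDF.as("t2"), col("t1.new_column") === col("t2.new_column"))
  .collect()
```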
1 parent 4927f63 commit 32e5f39

4 files changed: +32 −7 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala

Lines changed: 6 additions & 1 deletion
```diff
@@ -63,7 +63,12 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
             _))
         if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty =>
       val normalizedFilters = DataSourceStrategy.normalizeExprs(
-        filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)),
+        filters.filter { f =>
+          f.deterministic &&
+            !SubqueryExpression.hasSubquery(f) &&
+            // Python UDFs might exist because this rule is applied before ``ExtractPythonUDFs``.
+            !f.exists(_.isInstanceOf[PythonUDF])
+        },
         logicalRelation.output)
       val (partitionKeyFilters, _) = DataSourceUtils
         .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)
```
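The essential guard in this hunk is `!f.exists(_.isInstanceOf[PythonUDF])`: `Expression.exists` traverses the whole expression tree, so a UDF nested under `IsNotNull`, `And`, etc. is caught as well. A standalone restatement of the predicate, purely for illustration (the helper name is hypothetical, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF, SubqueryExpression}

// A filter may take part in partition pruning only if it is deterministic,
// contains no subquery, and contains no PythonUDF anywhere in its tree.
def isPartitionPrunable(f: Expression): Boolean =
  f.deterministic &&
    !SubqueryExpression.hasSubquery(f) &&
    !f.exists(_.isInstanceOf[PythonUDF])
```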

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.mutable
 
 import org.apache.spark.sql.{sources, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF, SubqueryExpression}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils}
@@ -73,7 +73,10 @@ abstract class FileScanBuilder(
     val (deterministicFilters, nonDeterminsticFilters) = filters.partition(_.deterministic)
     val (partitionFilters, dataFilters) =
       DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, deterministicFilters)
-    this.partitionFilters = partitionFilters
+    this.partitionFilters = partitionFilters.filter { f =>
+      // Python UDFs might exist because this rule is applied before ``ExtractPythonUDFs``.
+      !SubqueryExpression.hasSubquery(f) && !f.exists(_.isInstanceOf[PythonUDF])
+    }
     this.dataFilters = dataFilters
     val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter]
     for (filterExpr <- dataFilters) {
```
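The DataSource V2 read path applies the same idea when `FileScanBuilder` records which pushed filters may prune partitions. Restated with the hypothetical helper style from the sketch above (illustrative only; the determinism split is already handled separately in this method):

```scala
// Sketch: filters containing a subquery or a PythonUDF are simply kept out of the
// set used for partition pruning; they remain ordinary predicates in the plan and
// are handled once ExtractPythonUDFs has rewritten the Python UDF calls.
val prunablePartitionFilters: Seq[Expression] =
  partitionFilters.filter(f =>
    !SubqueryExpression.hasSubquery(f) && !f.exists(_.isInstanceOf[PythonUDF]))
```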

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala

Lines changed: 14 additions & 2 deletions
```diff
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest}
-import org.apache.spark.sql.functions.{array, count, transform}
+import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
+import org.apache.spark.sql.functions.{array, col, count, transform}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.LongType
 
@@ -124,4 +124,16 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
       context = ExpectedContext(
         "transform", s".*${this.getClass.getSimpleName}.*"))
   }
+
+  test("SPARK-48666: Python UDF execution against partitioned column") {
+    assume(shouldTestPythonUDFs)
+    withTable("t") {
+      spark.range(1).selectExpr("id AS t", "(id + 1) AS p").write.partitionBy("p").saveAsTable("t")
+      val table = spark.table("t")
+      val newTable = table.withColumn("new_column", pythonTestUDF(table("p")))
+      val df = newTable.as("t1").join(
+        newTable.as("t2"), col("t1.new_column") === col("t2.new_column"))
+      checkAnswer(df, Row(0, 1, 1, 0, 1, 1))
+    }
+  }
 }
```

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala

Lines changed: 7 additions & 2 deletions
```diff
@@ -22,7 +22,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.CastSupport
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, PythonUDF, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation
@@ -50,7 +50,12 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
       filters: Seq[Expression],
       relation: HiveTableRelation): ExpressionSet = {
     val normalizedFilters = DataSourceStrategy.normalizeExprs(
-      filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output)
+      filters.filter { f =>
+        f.deterministic &&
+          !SubqueryExpression.hasSubquery(f) &&
+          // Python UDFs might exist because this rule is applied before ``ExtractPythonUDFs``.
+          !f.exists(_.isInstanceOf[PythonUDF])
+      }, relation.output)
     val partitionColumnSet = AttributeSet(relation.partitionCols)
     ExpressionSet(
       normalizedFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionColumnSet)))
```
