Skip to content

Commit 9677b59

Browse files
committed
Do not push down filter if it contains Unevaluable expression
1 parent 7f496d2 commit 9677b59

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ trait OperationHelper extends AliasHelper with PredicateHelper {
6767
empty
6868
}
6969

70-
case Filter(condition, child) =>
70+
case Filter(condition, child) if !condition.exists(_.isInstanceOf[Unevaluable]) =>
7171
val (fields, filters, other, aliases) = collectProjectsAndFilters(child, alwaysInline)
7272
// When collecting projects and filters, we effectively push down filters through
7373
// projects. We need to meet the following conditions to do so:
@@ -115,6 +115,8 @@ object PhysicalOperation extends OperationHelper {
115115
val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline)
116116
// If more than 2 filters are collected, they must all be deterministic.
117117
if (filters.length > 1) assert(filters.forall(_.deterministic))
118+
// Unevaluable expressions should not be pushed
119+
assert(filters.forall(!_.exists(_.isInstanceOf[Unevaluable])))
118120
Some((
119121
fields.getOrElse(child.output),
120122
filters.flatMap(splitConjunctivePredicates),

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
package org.apache.spark.sql.execution.python
1919

20-
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest}
21-
import org.apache.spark.sql.functions.{array, count, transform}
20+
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
21+
import org.apache.spark.sql.functions.{array, col, count, transform}
2222
import org.apache.spark.sql.test.SharedSparkSession
2323
import org.apache.spark.sql.types.LongType
2424

@@ -124,4 +124,16 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
124124
context = ExpectedContext(
125125
"transform", s".*${this.getClass.getSimpleName}.*"))
126126
}
127+
128+
test("SPARK-48666: Python UDF execution against partitioned column") {
129+
assume(shouldTestPythonUDFs)
130+
withTable("t") {
131+
spark.range(1).selectExpr("id AS t", "(id + 1) AS p").write.partitionBy("p").saveAsTable("t")
132+
val table = spark.table("t")
133+
val newTable = table.withColumn("new_column", pythonTestUDF(table("p")))
134+
val df = newTable.as("t1").join(
135+
newTable.as("t2"), col("t1.new_column") === col("t2.new_column"))
136+
checkAnswer(df, Row(0, 1, 1, 0, 1, 1))
137+
}
138+
}
127139
}

0 commit comments

Comments
 (0)