Skip to content

Commit affe809

Browse files
committed
[SPARK-26147][SQL] only pull out unevaluable python udf from join condition
## What changes were proposed in this pull request? #22326 mistakenly treated every Python UDF in a join condition as unevaluable. In fact, only Python UDFs that refer to attributes from both join sides are unevaluable. This PR fixes that mistake. ## How was this patch tested? A new test. Closes #23153 from cloud-fan/join. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 438f8fd commit affe809

File tree

3 files changed

+106
-48
lines changed

3 files changed

+106
-48
lines changed

python/pyspark/sql/tests/test_udf.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,18 @@ def test_udf_in_join_condition(self):
209209
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
210210
self.assertEqual(df.collect(), [Row(a=1, b=1)])
211211

212+
def test_udf_in_left_outer_join_condition(self):
213+
# regression test for SPARK-26147
214+
from pyspark.sql.functions import udf, col
215+
left = self.spark.createDataFrame([Row(a=1)])
216+
right = self.spark.createDataFrame([Row(b=1)])
217+
f = udf(lambda a: str(a), StringType())
218+
# The join condition can't be pushed down, as it refers to attributes from both sides.
219+
# The Python UDF only refers to attributes from one side, so it's evaluable.
220+
df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
221+
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
222+
self.assertEqual(df.collect(), [Row(a=1, b=1)])
223+
212224
def test_udf_in_left_semi_join_condition(self):
213225
# regression test for SPARK-25314
214226
from pyspark.sql.functions import udf

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
155155
}
156156

157157
/**
158-
* PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF
159-
* and pull them out from join condition. For python udf accessing attributes from only one side,
160-
* they are pushed down by operation push down rules. If not (e.g. user disables filter push
161-
* down rules), we need to pull them out in this rule too.
158+
* PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides.
159+
* See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDFs and pull them
160+
* out from join condition.
162161
*/
163162
object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
164-
def hasPythonUDF(expression: Expression): Boolean = {
165-
expression.collectFirst { case udf: PythonUDF => udf }.isDefined
163+
164+
private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = {
165+
expr.find { e =>
166+
PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right)
167+
}.isDefined
166168
}
167169

168170
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
169-
case j @ Join(_, _, joinType, condition)
170-
if condition.isDefined && hasPythonUDF(condition.get) =>
171+
case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) =>
171172
if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
172173
// The current strategy only supports InnerLike and LeftSemi joins because for other types,
173174
// it breaks SQL semantics if we run the join condition as a filter after join. If we pass
@@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH
179180
}
180181
// If a conjunct of the condition contains an unevaluable python udf, it will be moved out from
181182
// the new join conditions.
182-
val (udf, rest) =
183-
splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
183+
val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j))
184184
val newCondition = if (rest.isEmpty) {
185-
logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," +
185+
logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," +
186186
s" it will be moved out and the join plan will be turned to cross join.")
187187
None
188188
} else {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala

Lines changed: 83 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20-
import org.scalatest.Matchers._
21-
2220
import org.apache.spark.api.python.PythonEvalType
2321
import org.apache.spark.sql.AnalysisException
2422
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -28,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
2826
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
2927
import org.apache.spark.sql.catalyst.rules.RuleExecutor
3028
import org.apache.spark.sql.internal.SQLConf._
31-
import org.apache.spark.sql.types.BooleanType
29+
import org.apache.spark.sql.types.{BooleanType, IntegerType}
3230

3331
class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
3432

@@ -40,13 +38,29 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
4038
CheckCartesianProducts) :: Nil
4139
}
4240

43-
val testRelationLeft = LocalRelation('a.int, 'b.int)
44-
val testRelationRight = LocalRelation('c.int, 'd.int)
41+
val attrA = 'a.int
42+
val attrB = 'b.int
43+
val attrC = 'c.int
44+
val attrD = 'd.int
45+
46+
val testRelationLeft = LocalRelation(attrA, attrB)
47+
val testRelationRight = LocalRelation(attrC, attrD)
48+
49+
// This join condition refers to attributes from 2 tables, but the PythonUDF inside it only
50+
// refers to attributes from one side.
51+
val evaluableJoinCond = {
52+
val pythonUDF = PythonUDF("evaluable", null,
53+
IntegerType,
54+
Seq(attrA),
55+
PythonEvalType.SQL_BATCHED_UDF,
56+
udfDeterministic = true)
57+
pythonUDF === attrC
58+
}
4559

46-
// Dummy python UDF for testing. Unable to execute.
47-
val pythonUDF = PythonUDF("pythonUDF", null,
60+
// This join condition is a PythonUDF which refers to attributes from 2 tables.
61+
val unevaluableJoinCond = PythonUDF("unevaluable", null,
4862
BooleanType,
49-
Seq.empty,
63+
Seq(attrA, attrC),
5064
PythonEvalType.SQL_BATCHED_UDF,
5165
udfDeterministic = true)
5266

@@ -66,62 +80,76 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
6680
}
6781
}
6882

69-
test("inner join condition with python udf only") {
70-
val query = testRelationLeft.join(
83+
test("inner join condition with python udf") {
84+
val query1 = testRelationLeft.join(
7185
testRelationRight,
7286
joinType = Inner,
73-
condition = Some(pythonUDF))
74-
val expected = testRelationLeft.join(
87+
condition = Some(unevaluableJoinCond))
88+
val expected1 = testRelationLeft.join(
7589
testRelationRight,
7690
joinType = Inner,
77-
condition = None).where(pythonUDF).analyze
78-
comparePlanWithCrossJoinEnable(query, expected)
91+
condition = None).where(unevaluableJoinCond).analyze
92+
comparePlanWithCrossJoinEnable(query1, expected1)
93+
94+
// evaluable PythonUDF will not be touched
95+
val query2 = testRelationLeft.join(
96+
testRelationRight,
97+
joinType = Inner,
98+
condition = Some(evaluableJoinCond))
99+
comparePlans(Optimize.execute(query2), query2)
79100
}
80101

81-
test("left semi join condition with python udf only") {
82-
val query = testRelationLeft.join(
102+
test("left semi join condition with python udf") {
103+
val query1 = testRelationLeft.join(
83104
testRelationRight,
84105
joinType = LeftSemi,
85-
condition = Some(pythonUDF))
86-
val expected = testRelationLeft.join(
106+
condition = Some(unevaluableJoinCond))
107+
val expected1 = testRelationLeft.join(
87108
testRelationRight,
88109
joinType = Inner,
89-
condition = None).where(pythonUDF).select('a, 'b).analyze
90-
comparePlanWithCrossJoinEnable(query, expected)
110+
condition = None).where(unevaluableJoinCond).select('a, 'b).analyze
111+
comparePlanWithCrossJoinEnable(query1, expected1)
112+
113+
// evaluable PythonUDF will not be touched
114+
val query2 = testRelationLeft.join(
115+
testRelationRight,
116+
joinType = LeftSemi,
117+
condition = Some(evaluableJoinCond))
118+
comparePlans(Optimize.execute(query2), query2)
91119
}
92120

93-
test("python udf and common condition") {
121+
test("unevaluable python udf and common condition") {
94122
val query = testRelationLeft.join(
95123
testRelationRight,
96124
joinType = Inner,
97-
condition = Some(pythonUDF && 'a.attr === 'c.attr))
125+
condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr))
98126
val expected = testRelationLeft.join(
99127
testRelationRight,
100128
joinType = Inner,
101-
condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze
129+
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze
102130
val optimized = Optimize.execute(query.analyze)
103131
comparePlans(optimized, expected)
104132
}
105133

106-
test("python udf or common condition") {
134+
test("unevaluable python udf or common condition") {
107135
val query = testRelationLeft.join(
108136
testRelationRight,
109137
joinType = Inner,
110-
condition = Some(pythonUDF || 'a.attr === 'c.attr))
138+
condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr))
111139
val expected = testRelationLeft.join(
112140
testRelationRight,
113141
joinType = Inner,
114-
condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze
142+
condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze
115143
comparePlanWithCrossJoinEnable(query, expected)
116144
}
117145

118-
test("pull out whole complex condition with multiple python udf") {
146+
test("pull out whole complex condition with multiple unevaluable python udf") {
119147
val pythonUDF1 = PythonUDF("pythonUDF1", null,
120148
BooleanType,
121-
Seq.empty,
149+
Seq(attrA, attrC),
122150
PythonEvalType.SQL_BATCHED_UDF,
123151
udfDeterministic = true)
124-
val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1
152+
val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1
125153

126154
val query = testRelationLeft.join(
127155
testRelationRight,
@@ -134,13 +162,13 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
134162
comparePlanWithCrossJoinEnable(query, expected)
135163
}
136164

137-
test("partial pull out complex condition with multiple python udf") {
165+
test("partial pull out complex condition with multiple unevaluable python udf") {
138166
val pythonUDF1 = PythonUDF("pythonUDF1", null,
139167
BooleanType,
140-
Seq.empty,
168+
Seq(attrA, attrC),
141169
PythonEvalType.SQL_BATCHED_UDF,
142170
udfDeterministic = true)
143-
val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr
171+
val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr
144172

145173
val query = testRelationLeft.join(
146174
testRelationRight,
@@ -149,23 +177,41 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
149177
val expected = testRelationLeft.join(
150178
testRelationRight,
151179
joinType = Inner,
152-
condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze
180+
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze
181+
val optimized = Optimize.execute(query.analyze)
182+
comparePlans(optimized, expected)
183+
}
184+
185+
test("pull out unevaluable python udf when it's mixed with evaluable one") {
186+
val query = testRelationLeft.join(
187+
testRelationRight,
188+
joinType = Inner,
189+
condition = Some(evaluableJoinCond && unevaluableJoinCond))
190+
val expected = testRelationLeft.join(
191+
testRelationRight,
192+
joinType = Inner,
193+
condition = Some(evaluableJoinCond)).where(unevaluableJoinCond).analyze
153194
val optimized = Optimize.execute(query.analyze)
154195
comparePlans(optimized, expected)
155196
}
156197

157198
test("throw an exception for not support join type") {
158199
for (joinType <- unsupportedJoinTypes) {
159-
val thrownException = the [AnalysisException] thrownBy {
200+
val e = intercept[AnalysisException] {
160201
val query = testRelationLeft.join(
161202
testRelationRight,
162203
joinType,
163-
condition = Some(pythonUDF))
204+
condition = Some(unevaluableJoinCond))
164205
Optimize.execute(query.analyze)
165206
}
166-
assert(thrownException.message.contentEquals(
207+
assert(e.message.contentEquals(
167208
s"Using PythonUDF in join condition of join type $joinType is not supported."))
209+
210+
val query2 = testRelationLeft.join(
211+
testRelationRight,
212+
joinType,
213+
condition = Some(evaluableJoinCond))
214+
comparePlans(Optimize.execute(query2), query2)
168215
}
169216
}
170217
}
171-

0 commit comments

Comments
 (0)