Skip to content

Commit 81bb5cb

Browse files
cloud-fankai-chi
authored andcommitted
[SPARK-26147][SQL] only pull out unevaluable python udf from join condition
apache#22326 made a mistake that, not all python UDFs are unevaluable in join condition. Only python UDFs that refer to attributes from both join side are unevaluable. This PR fixes this mistake. a new test Closes apache#23153 from cloud-fan/join. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com> (cherry picked from commit affe809) Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 1270c09 commit 81bb5cb

File tree

3 files changed

+240
-11
lines changed

3 files changed

+240
-11
lines changed

python/pyspark/sql/tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,18 @@ def test_udf_in_join_condition(self):
564564
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
565565
self.assertEqual(df.collect(), [Row(a=1, b=1)])
566566

567+
def test_udf_in_left_outer_join_condition(self):
568+
# regression test for SPARK-26147
569+
from pyspark.sql.functions import udf, col
570+
left = self.spark.createDataFrame([Row(a=1)])
571+
right = self.spark.createDataFrame([Row(b=1)])
572+
f = udf(lambda a: str(a), StringType())
573+
# The join condition can't be pushed down, as it refers to attributes from both sides.
574+
# The Python UDF only refer to attributes from one side, so it's evaluable.
575+
df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
576+
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
577+
self.assertEqual(df.collect(), [Row(a=1, b=1)])
578+
567579
def test_udf_in_left_semi_join_condition(self):
568580
# regression test for SPARK-25314
569581
from pyspark.sql.functions import udf

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
155155
}
156156

157157
/**
158-
* PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF
159-
* and pull them out from join condition. For python udf accessing attributes from only one side,
160-
* they are pushed down by operation push down rules. If not (e.g. user disables filter push
161-
* down rules), we need to pull them out in this rule too.
158+
* PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides.
159+
* See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them
160+
* out from join condition.
162161
*/
163162
object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
164-
def hasPythonUDF(expression: Expression): Boolean = {
165-
expression.collectFirst { case udf: PythonUDF => udf }.isDefined
163+
164+
private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = {
165+
expr.find { e =>
166+
PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right)
167+
}.isDefined
166168
}
167169

168170
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
169-
case j @ Join(_, _, joinType, condition)
170-
if condition.isDefined && hasPythonUDF(condition.get) =>
171+
case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) =>
171172
if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
172173
// The current strategy only support InnerLike and LeftSemi join because for other type,
173174
// it breaks SQL semantic if we run the join condition as a filter after join. If we pass
@@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH
179180
}
180181
// If condition expression contains python udf, it will be moved out from
181182
// the new join conditions.
182-
val (udf, rest) =
183-
splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
183+
val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j))
184184
val newCondition = if (rest.isEmpty) {
185-
logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," +
185+
logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," +
186186
s" it will be moved out and the join plan will be turned to cross join.")
187187
None
188188
} else {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.api.python.PythonEvalType
21+
import org.apache.spark.sql.AnalysisException
22+
import org.apache.spark.sql.catalyst.dsl.expressions._
23+
import org.apache.spark.sql.catalyst.dsl.plans._
24+
import org.apache.spark.sql.catalyst.expressions.PythonUDF
25+
import org.apache.spark.sql.catalyst.plans._
26+
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
27+
import org.apache.spark.sql.catalyst.rules.RuleExecutor
28+
import org.apache.spark.sql.internal.SQLConf._
29+
import org.apache.spark.sql.types.{BooleanType, IntegerType}
30+
31+
class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
32+
33+
object Optimize extends RuleExecutor[LogicalPlan] {
34+
val batches =
35+
Batch("Extract PythonUDF From JoinCondition", Once,
36+
PullOutPythonUDFInJoinCondition) ::
37+
Batch("Check Cartesian Products", Once,
38+
CheckCartesianProducts) :: Nil
39+
}
40+
41+
val attrA = 'a.int
42+
val attrB = 'b.int
43+
val attrC = 'c.int
44+
val attrD = 'd.int
45+
46+
val testRelationLeft = LocalRelation(attrA, attrB)
47+
val testRelationRight = LocalRelation(attrC, attrD)
48+
49+
// This join condition refers to attributes from 2 tables, but the PythonUDF inside it only
50+
// refer to attributes from one side.
51+
val evaluableJoinCond = {
52+
val pythonUDF = PythonUDF("evaluable", null,
53+
IntegerType,
54+
Seq(attrA),
55+
PythonEvalType.SQL_BATCHED_UDF,
56+
udfDeterministic = true)
57+
pythonUDF === attrC
58+
}
59+
60+
// This join condition is a PythonUDF which refers to attributes from 2 tables.
61+
val unevaluableJoinCond = PythonUDF("unevaluable", null,
62+
BooleanType,
63+
Seq(attrA, attrC),
64+
PythonEvalType.SQL_BATCHED_UDF,
65+
udfDeterministic = true)
66+
67+
val unsupportedJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti)
68+
69+
private def comparePlanWithCrossJoinEnable(query: LogicalPlan, expected: LogicalPlan): Unit = {
70+
// AnalysisException thrown by CheckCartesianProducts while spark.sql.crossJoin.enabled=false
71+
val exception = intercept[AnalysisException] {
72+
Optimize.execute(query.analyze)
73+
}
74+
assert(exception.message.startsWith("Detected implicit cartesian product"))
75+
76+
// pull out the python udf while set spark.sql.crossJoin.enabled=true
77+
withSQLConf(CROSS_JOINS_ENABLED.key -> "true") {
78+
val optimized = Optimize.execute(query.analyze)
79+
comparePlans(optimized, expected)
80+
}
81+
}
82+
83+
test("inner join condition with python udf") {
84+
val query1 = testRelationLeft.join(
85+
testRelationRight,
86+
joinType = Inner,
87+
condition = Some(unevaluableJoinCond))
88+
val expected1 = testRelationLeft.join(
89+
testRelationRight,
90+
joinType = Inner,
91+
condition = None).where(unevaluableJoinCond).analyze
92+
comparePlanWithCrossJoinEnable(query1, expected1)
93+
94+
// evaluable PythonUDF will not be touched
95+
val query2 = testRelationLeft.join(
96+
testRelationRight,
97+
joinType = Inner,
98+
condition = Some(evaluableJoinCond))
99+
comparePlans(Optimize.execute(query2), query2)
100+
}
101+
102+
test("left semi join condition with python udf") {
103+
val query1 = testRelationLeft.join(
104+
testRelationRight,
105+
joinType = LeftSemi,
106+
condition = Some(unevaluableJoinCond))
107+
val expected1 = testRelationLeft.join(
108+
testRelationRight,
109+
joinType = Inner,
110+
condition = None).where(unevaluableJoinCond).select('a, 'b).analyze
111+
comparePlanWithCrossJoinEnable(query1, expected1)
112+
113+
// evaluable PythonUDF will not be touched
114+
val query2 = testRelationLeft.join(
115+
testRelationRight,
116+
joinType = LeftSemi,
117+
condition = Some(evaluableJoinCond))
118+
comparePlans(Optimize.execute(query2), query2)
119+
}
120+
121+
test("unevaluable python udf and common condition") {
122+
val query = testRelationLeft.join(
123+
testRelationRight,
124+
joinType = Inner,
125+
condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr))
126+
val expected = testRelationLeft.join(
127+
testRelationRight,
128+
joinType = Inner,
129+
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze
130+
val optimized = Optimize.execute(query.analyze)
131+
comparePlans(optimized, expected)
132+
}
133+
134+
test("unevaluable python udf or common condition") {
135+
val query = testRelationLeft.join(
136+
testRelationRight,
137+
joinType = Inner,
138+
condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr))
139+
val expected = testRelationLeft.join(
140+
testRelationRight,
141+
joinType = Inner,
142+
condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze
143+
comparePlanWithCrossJoinEnable(query, expected)
144+
}
145+
146+
test("pull out whole complex condition with multiple unevaluable python udf") {
147+
val pythonUDF1 = PythonUDF("pythonUDF1", null,
148+
BooleanType,
149+
Seq(attrA, attrC),
150+
PythonEvalType.SQL_BATCHED_UDF,
151+
udfDeterministic = true)
152+
val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1
153+
154+
val query = testRelationLeft.join(
155+
testRelationRight,
156+
joinType = Inner,
157+
condition = Some(condition))
158+
val expected = testRelationLeft.join(
159+
testRelationRight,
160+
joinType = Inner,
161+
condition = None).where(condition).analyze
162+
comparePlanWithCrossJoinEnable(query, expected)
163+
}
164+
165+
test("partial pull out complex condition with multiple unevaluable python udf") {
166+
val pythonUDF1 = PythonUDF("pythonUDF1", null,
167+
BooleanType,
168+
Seq(attrA, attrC),
169+
PythonEvalType.SQL_BATCHED_UDF,
170+
udfDeterministic = true)
171+
val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr
172+
173+
val query = testRelationLeft.join(
174+
testRelationRight,
175+
joinType = Inner,
176+
condition = Some(condition))
177+
val expected = testRelationLeft.join(
178+
testRelationRight,
179+
joinType = Inner,
180+
condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze
181+
val optimized = Optimize.execute(query.analyze)
182+
comparePlans(optimized, expected)
183+
}
184+
185+
test("pull out unevaluable python udf when it's mixed with evaluable one") {
186+
val query = testRelationLeft.join(
187+
testRelationRight,
188+
joinType = Inner,
189+
condition = Some(evaluableJoinCond && unevaluableJoinCond))
190+
val expected = testRelationLeft.join(
191+
testRelationRight,
192+
joinType = Inner,
193+
condition = Some(evaluableJoinCond)).where(unevaluableJoinCond).analyze
194+
val optimized = Optimize.execute(query.analyze)
195+
comparePlans(optimized, expected)
196+
}
197+
198+
test("throw an exception for not support join type") {
199+
for (joinType <- unsupportedJoinTypes) {
200+
val e = intercept[AnalysisException] {
201+
val query = testRelationLeft.join(
202+
testRelationRight,
203+
joinType,
204+
condition = Some(unevaluableJoinCond))
205+
Optimize.execute(query.analyze)
206+
}
207+
assert(e.message.contentEquals(
208+
s"Using PythonUDF in join condition of join type $joinType is not supported."))
209+
210+
val query2 = testRelationLeft.join(
211+
testRelationRight,
212+
joinType,
213+
condition = Some(evaluableJoinCond))
214+
comparePlans(Optimize.execute(query2), query2)
215+
}
216+
}
217+
}

0 commit comments

Comments
 (0)