
Commit 53eb858

xuanyuanking authored and cloud-fan committed
[SPARK-25314][SQL] Fix Python UDF accessing attributes from both side of join in join conditions
## What changes were proposed in this pull request?

Thanks to bahchis for reporting this. This is a follow-up to #16581: this PR fixes the scenario of a Python UDF accessing attributes from both sides of a join in the join condition.

## How was this patch tested?

Added regression tests in PySpark and `BatchEvalPythonExecSuite`.

Closes #22326 from xuanyuanking/SPARK-25314.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 2a8cbfd)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 0b4e581 commit 53eb858
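
For context, here is a minimal PySpark sketch of the user-facing scenario this commit fixes; it mirrors the regression tests below and assumes an active `SparkSession` bound to `spark` (the names `left`, `right`, and `f` are illustrative):

```python
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

left = spark.createDataFrame([Row(a=1)])
right = spark.createDataFrame([Row(b=1)])

# The UDF references attributes from both sides of the join, so it can be
# pushed down to neither side and must run after the join itself.
f = udf(lambda a, b: a == b, BooleanType())

# With this patch the UDF is pulled out of the join condition and evaluated
# as a filter on top of the join. Since it was the only condition, the join
# becomes a cross join and requires spark.sql.crossJoin.enabled=true.
spark.conf.set("spark.sql.crossJoin.enabled", "true")
left.join(right, f("a", "b")).show()  # one row: a=1, b=1
```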

File tree (3 files changed: +119 −2 lines changed)

- python/pyspark/sql/tests.py
- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala


python/pyspark/sql/tests.py

Lines changed: 64 additions & 0 deletions
@@ -552,6 +552,70 @@ def test_udf_in_filter_on_top_of_join(self):
         df = left.crossJoin(right).filter(f("a", "b"))
         self.assertEqual(df.collect(), [Row(a=1, b=1)])
 
+    def test_udf_in_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"))
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+    def test_udf_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"), "leftsemi")
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_and_common_filter_in_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both udf and common filter
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1])
+        # spark.sql.crossJoin.enabled=true is not needed, since the udf is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+
+    def test_udf_and_common_filter_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both udf and common filter
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi")
+        # spark.sql.crossJoin.enabled=true is not needed, since the udf is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_not_supported_in_join_condition(self):
+        # regression test for SPARK-25314
+        # test that python udf is not supported in join types besides left_semi and inner join.
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+
+        def runWithJoinType(join_type, type_string):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    'Using PythonUDF.*%s is not supported.' % type_string):
+                left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
+        runWithJoinType("full", "FullOuter")
+        runWithJoinType("left", "LeftOuter")
+        runWithJoinType("right", "RightOuter")
+        runWithJoinType("leftanti", "LeftAnti")
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 6 additions & 2 deletions
@@ -165,7 +165,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation,
       PropagateEmptyRelation) :+
-    // The following batch should be executed after batch "Join Reorder" and "LocalRelation".
+    Batch("Extract PythonUDF From JoinCondition", Once,
+      PullOutPythonUDFInJoinCondition) :+
+    // The following batch should be executed after batches "Join Reorder", "LocalRelation" and
+    // "Extract PythonUDF From JoinCondition".
     Batch("Check Cartesian Products", Once,
       CheckCartesianProducts) :+
     Batch("RewriteSubquery", Once,
@@ -202,7 +205,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       ReplaceDistinctWithAggregate.ruleName ::
       PullupCorrelatedPredicates.ruleName ::
       RewriteCorrelatedScalarSubquery.ruleName ::
-      RewritePredicateSubquery.ruleName :: Nil
+      RewritePredicateSubquery.ruleName ::
+      PullOutPythonUDFInJoinCondition.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.
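
The batch ordering above matters: "Extract PythonUDF From JoinCondition" must run before "Check Cartesian Products", since pulling a UDF-only condition out of a join leaves a condition-less join that the cartesian-product check should then detect. A sketch of that interplay, mirroring `test_udf_in_join_condition` and assuming a `SparkSession` named `spark` with the default `spark.sql.crossJoin.enabled=false`:

```python
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType
from pyspark.sql.utils import AnalysisException

left = spark.createDataFrame([Row(a=1)])
right = spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())

try:
    # The UDF is the only join condition, so after the new batch runs, the
    # join has no condition left and CheckCartesianProducts rejects the plan.
    left.join(right, f("a", "b")).collect()
except AnalysisException as e:
    print(e)  # 'Detected implicit cartesian product ...'
```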

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 49 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.annotation.tailrec
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
 import org.apache.spark.sql.catalyst.plans._
@@ -152,3 +153,51 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
     if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))
   }
 }
+
+/**
+ * A PythonUDF in a join condition cannot be evaluated by the join itself, so this rule
+ * detects such PythonUDFs and pulls them out of the join condition. Python UDFs that access
+ * attributes from only one side are pushed down by the predicate push-down rules; if they
+ * are not (e.g. the user disables the filter push-down rules), this rule pulls them out too.
+ */
+object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
+  def hasPythonUDF(expression: Expression): Boolean = {
+    expression.collectFirst { case udf: PythonUDF => udf }.isDefined
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case j @ Join(_, _, joinType, condition)
+        if condition.isDefined && hasPythonUDF(condition.get) =>
+      if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
+        // The current strategy supports only InnerLike and LeftSemi joins, because for other
+        // types running the join condition as a filter after the join breaks SQL semantics.
+        // If we passed the plan on here, it would still fail with an invalid PythonUDF
+        // RuntimeException saying `requires attributes from more than one child`; we throw
+        // here first for a more readable error message.
+        throw new AnalysisException("Using PythonUDF in join condition of join type" +
+          s" $joinType is not supported.")
+      }
+      // If the condition expression contains a python udf, it is moved out of
+      // the new join condition.
+      val (udf, rest) =
+        splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
+      val newCondition = if (rest.isEmpty) {
+        logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," +
+          s" it will be moved out and the join plan will be turned to cross join.")
+        None
+      } else {
+        Some(rest.reduceLeft(And))
+      }
+      val newJoin = j.copy(condition = newCondition)
+      joinType match {
+        case _: InnerLike => Filter(udf.reduceLeft(And), newJoin)
+        case LeftSemi =>
+          Project(
+            j.left.output.map(_.toAttribute),
+            Filter(udf.reduceLeft(And), newJoin.copy(joinType = Inner)))
+        case _ =>
+          throw new AnalysisException("Using PythonUDF in join condition of join type" +
+            s" $joinType is not supported.")
+      }
+  }
+}
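
To illustrate the rewrite this rule performs, here is a small sketch (not part of the commit) assuming a `SparkSession` named `spark`: for an inner join, the evaluable predicates stay in the join condition while the PythonUDF becomes a `Filter` on top of the join.

```python
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

left = spark.createDataFrame([Row(a=1, a1=1), Row(a=2, a1=2)])
right = spark.createDataFrame([Row(b=1, b1=1)])
f = udf(lambda a, b: a == b, BooleanType())

# Conceptually, Join(left, right, Inner, f(a, b) AND a1 = b1) is rewritten to
#   Filter(f(a, b))
#   +- Join(left, right, Inner, a1 = b1)
# so no cross join is needed when the UDF is not the only condition.
df = left.join(right, [f("a", "b"), left.a1 == right.b1])
df.explain(True)  # the optimized plan shows the PythonUDF in a Filter above the Join
df.show()         # one row: a=1, a1=1, b=1, b1=1
```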
