Commit 613986d

Davies Liu authored and cmonkey committed
[SPARK-18589][SQL] Fix Python UDF accessing attributes from both side of join
## What changes were proposed in this pull request?

PythonUDF is unevaluable, so it cannot be used inside a join condition. Currently the optimizer pushes a PythonUDF that accesses attributes from both sides of a join into the join condition, and the query then fails to plan. This PR fixes the issue by checking whether an expression is evaluable before pushing it into the Join.

## How was this patch tested?

Added a regression test.

Author: Davies Liu <davies@databricks.com>

Closes apache#16581 from davies/pyudf_join.
1 parent 2002973 commit 613986d
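For context, here is a minimal PySpark sketch of the failing pattern described above (not part of this commit; assumes a running SparkSession named `spark`). The filter uses a Python UDF whose arguments come from both sides of the join, which the optimizer previously pushed into the join condition even though a PythonUDF cannot be evaluated there:

# Hypothetical repro sketch for SPARK-18589.
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

left = spark.createDataFrame([Row(a=1)])
right = spark.createDataFrame([Row(b=1)])

# A Python UDF referencing attributes from both join sides.
both_sides = udf(lambda a, b: a == b, BooleanType())

# Before this patch the optimizer moved the UDF into the join condition
# and planning failed; with the fix the UDF stays in a Filter above the
# join and the query returns [Row(a=1, b=1)].
df = left.crossJoin(right).filter(both_sides("a", "b"))
print(df.collect())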

File tree

5 files changed: +30 -13 lines changed

python/pyspark/sql/tests.py

Lines changed: 9 additions & 0 deletions

@@ -342,6 +342,15 @@ def test_udf_in_filter_on_top_of_outer_join(self):
         df = df.withColumn('b', udf(lambda x: 'x')(df.a))
         self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
 
+    def test_udf_in_filter_on_top_of_join(self):
+        # regression test for SPARK-18589
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.crossJoin(right).filter(f("a", "b"))
+        self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()
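As a quick sanity check outside the test suite (a sketch, reusing `left`, `right`, and `both_sides` from the earlier snippet), the physical plan can be printed to confirm where the UDF is evaluated; after this fix it should show the Python UDF evaluated in its own node feeding a Filter above the join, instead of failing to plan:

# Sketch: inspect the physical plan of the regression-test query.
df = left.crossJoin(right).filter(both_sides("a", "b"))
df.explain()  # expect the Python UDF evaluated above the join, not inside it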

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 12 additions & 1 deletion

@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 
 object InterpretedPredicate {
@@ -86,6 +85,18 @@ trait PredicateHelper {
    */
   protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
     expr.references.subsetOf(plan.outputSet)
+
+  /**
+   * Returns true iff `expr` could be evaluated as a condition within join.
+   */
+  protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match {
+    case e: SubqueryExpression =>
+      // non-correlated subquery will be replaced as literal
+      e.children.isEmpty
+    case a: AttributeReference => true
+    case e: Unevaluable => false
+    case e => e.children.forall(canEvaluateWithinJoin)
+  }
 }
 
 
@ExpressionDescription(
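To make the new check concrete, here is a toy Python model of `canEvaluateWithinJoin` (illustrative class names, not Spark APIs). A subquery expression is acceptable only when it is non-correlated (no children), since it will later be rewritten to a literal; a plain attribute reference is acceptable; any other unevaluable node, such as a PythonUDF, rejects the whole condition; everything else is acceptable iff all of its children are. Note the case order in the Scala match: `AttributeReference` is handled before `Unevaluable` because in Catalyst attribute references are themselves unevaluable until bound to an input row.

# Toy model of canEvaluateWithinJoin; names are illustrative, not Spark's.
class Expr:
    def __init__(self, *children):
        self.children = list(children)

class SubqueryExpr(Expr): pass    # correlated iff it has children
class AttributeRef(Expr): pass    # a plain column reference
class PythonUdfExpr(Expr): pass   # stand-in for an Unevaluable expression
class EqualTo(Expr): pass

UNEVALUABLE = (PythonUdfExpr,)    # Catalyst marks these via the Unevaluable trait

def can_evaluate_within_join(expr):
    if isinstance(expr, SubqueryExpr):
        # a non-correlated subquery is replaced by a literal later on
        return not expr.children
    if isinstance(expr, AttributeRef):
        return True
    if isinstance(expr, UNEVALUABLE):
        return False
    return all(can_evaluate_within_join(c) for c in expr.children)

# An equality over two plain columns may go into the join condition...
assert can_evaluate_within_join(EqualTo(AttributeRef(), AttributeRef()))
# ...but one wrapping a Python UDF must stay in a Filter above the join.
assert not can_evaluate_within_join(EqualTo(PythonUdfExpr(), AttributeRef()))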

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion

@@ -893,7 +893,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
       val newRight = rightFilterConditions.
         reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
       val (newJoinConditions, others) =
-        commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
+        commonFilterCondition.partition(canEvaluateWithinJoin)
       val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
 
       val join = Join(newLeft, newRight, joinType, newJoinCond)
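The effect on PushPredicateThroughJoin, sketched in PySpark terms (assumed setup, reusing `left`, `right`, and `both_sides` from the earlier snippet): a filter over a join that mixes an ordinary equality with a Python UDF is now split, with the evaluable equality pushed into the join condition while the unevaluable UDF remains in a Filter above the join.

# Sketch: mixed filter conditions over a join.
mixed = left.crossJoin(right) \
    .filter((left.a == right.b) & both_sides("a", "b"))
# The optimized plan should show the equality inside the Join and the
# Python UDF in a separate Filter; previously both conjuncts were pushed
# into the Join and planning failed.
mixed.explain(True)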

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 2 additions & 3 deletions

@@ -46,8 +46,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
     : LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
-      val (joinConditions, others) = conditions.partition(
-        e => !SubqueryExpression.hasCorrelatedSubquery(e))
+      val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
       val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1))
       val innerJoinType = (leftJoinType, rightJoinType) match {
         case (Inner, Inner) => Inner
@@ -75,7 +74,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
 
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
-        e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
+        e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
       val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And))
 
       // should not have reference to same logical plan
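Note that the new predicate subsumes the old `!SubqueryExpression.hasCorrelatedSubquery(e)` check: a correlated SubqueryExpression has children, so `canEvaluateWithinJoin` still rejects it, and it additionally rejects any other unevaluable expression such as a PythonUDF. A hedged SQL-level illustration (assumes tables `l`, `r`, and `t` are registered):

# Sketch: a correlated EXISTS predicate is likewise kept out of the
# reordered join conditions, just as under the old check.
spark.sql("""
    SELECT * FROM l JOIN r
    WHERE l.a = r.b AND EXISTS (SELECT 1 FROM t WHERE t.x = l.a)
""").explain(True)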

sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala

Lines changed: 6 additions & 8 deletions

@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, In}
 import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.BooleanType
@@ -86,13 +86,11 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
   test("Python UDF refers to the attributes from more than one child") {
     val df = Seq(("Hello", 4)).toDF("a", "b")
     val df2 = Seq(("Hello", 4)).toDF("c", "d")
-    val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
-
-    val e = intercept[RuntimeException] {
-      joinDF.queryExecution.executedPlan
-    }.getMessage
-    assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child")
-      .forall(e.contains))
+    val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
+    val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect {
+      case b: BatchEvalPythonExec => b
+    }
+    assert(qualifiedPlanNodes.size == 1)
   }
 }
 
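A rough PySpark analogue of the rewritten Scala assertion, for the `df` built in the earlier sketch (a sketch only; `_jdf` and the JVM-side QueryExecution are internal, non-public APIs that may change between versions):

# Sketch: assert the UDF runs in its own BatchEvalPython exec node.
plan = df._jdf.queryExecution().executedPlan().toString()
assert "BatchEvalPython" in plan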
