[SPARK-25314][SQL] Fix Python UDF accessing attributes from both side of join in join conditions #22326

Closed · wants to merge 24 commits

Changes from all commits (24 commits):
9c579bd
Fix Python UDF accessing attributes from both side of join in join con…
xuanyuanking Sep 4, 2018
9ea1cf6
Address comments
xuanyuanking Sep 4, 2018
b6b0aa6
Add crossJoinEnabled checking logic.
xuanyuanking Sep 4, 2018
b626fa7
Address comments
xuanyuanking Sep 4, 2018
53dd028
Fix the left semi join and more tests
xuanyuanking Sep 5, 2018
4ca7fd1
fix python style
xuanyuanking Sep 5, 2018
1109eb3
Address comments
xuanyuanking Sep 7, 2018
c6345fe
Delete mistake commit
xuanyuanking Sep 7, 2018
83660d5
Address comments
xuanyuanking Sep 8, 2018
fdc86ca
Add UT for common filter and udf
xuanyuanking Sep 8, 2018
fbf32f4
limit the change scope only to PythonUDF
xuanyuanking Sep 13, 2018
292b09c
Reimplement the logic in Analyzer instead of Optimizer
xuanyuanking Sep 22, 2018
6749a96
config should be set in analyzer
xuanyuanking Sep 23, 2018
a598a4e
Move the rule to optimizer
xuanyuanking Sep 25, 2018
a2c8ddd
style fix
xuanyuanking Sep 25, 2018
005bb3f
Address comments from Wenchen
xuanyuanking Sep 25, 2018
b0dfab3
fix exhaustive match
xuanyuanking Sep 25, 2018
306fcb9
Move cross join detection logic into CheckCartesianProducts
xuanyuanking Sep 25, 2018
98cd3cc
Address comment
xuanyuanking Sep 26, 2018
d1db33a
Revert the changes of original plan and enhance UT
xuanyuanking Sep 26, 2018
87f0f50
Address comments
xuanyuanking Sep 26, 2018
d2739af
Address comments from Marco
xuanyuanking Sep 26, 2018
7f66954
Address comment
xuanyuanking Sep 26, 2018
2b6977d
Delete unnecessary end-to-end tests
xuanyuanking Sep 27, 2018
64 changes: 64 additions & 0 deletions python/pyspark/sql/tests.py
@@ -552,6 +552,70 @@ def test_udf_in_filter_on_top_of_join(self):
df = left.crossJoin(right).filter(f("a", "b"))
self.assertEqual(df.collect(), [Row(a=1, b=1)])

def test_udf_in_join_condition(self):
# regression test for SPARK-25314
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
df = left.join(right, f("a", "b"))
with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
df.collect()
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
self.assertEqual(df.collect(), [Row(a=1, b=1)])

def test_udf_in_left_semi_join_condition(self):
# regression test for SPARK-25314
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
df = left.join(right, f("a", "b"), "leftsemi")
with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
df.collect()
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])

def test_udf_and_common_filter_in_join_condition(self):
# regression test for SPARK-25314
# test the complex scenario with both udf and common filter
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
df = left.join(right, [f("a", "b"), left.a1 == right.b1])
# spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])

def test_udf_and_common_filter_in_left_semi_join_condition(self):
# regression test for SPARK-25314
# test the complex scenario with both udf and common filter
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi")
# spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])

def test_udf_not_supported_in_join_condition(self):
# regression test for SPARK-25314
# test that a python udf is not supported in join types besides left_semi and inner join.
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())

def runWithJoinType(join_type, type_string):
with self.assertRaisesRegexp(
AnalysisException,
'Using PythonUDF.*%s is not supported.' % type_string):
left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
runWithJoinType("full", "FullOuter")
runWithJoinType("left", "LeftOuter")
runWithJoinType("right", "RightOuter")
runWithJoinType("leftanti", "LeftAnti")

def test_udf_without_arguments(self):
self.spark.catalog.registerFunction("foo", lambda: "bar")
[row] = self.spark.sql("SELECT foo()").collect()
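The optimizer rule exercised by these tests splits a conjunctive join condition into PythonUDF conjuncts and ordinary conjuncts. Below is a minimal pure-Python sketch of that partition step using a toy tuple-based expression tree; the helper names (`split_conjuncts`, `has_udf`, `pull_out_udfs`) are hypothetical illustrations, not Spark APIs.

```python
# Toy expression nodes: ("and", left, right), ("udf", *args), ("eq", a, b)

def split_conjuncts(expr):
    """Flatten nested ANDs into a list of conjuncts (cf. splitConjunctivePredicates)."""
    if isinstance(expr, tuple) and expr[0] == "and":
        return split_conjuncts(expr[1]) + split_conjuncts(expr[2])
    return [expr]

def has_udf(expr):
    """True if any node in the tree is a python udf call."""
    if not isinstance(expr, tuple):
        return False
    return expr[0] == "udf" or any(has_udf(child) for child in expr[1:])

def pull_out_udfs(condition):
    """Partition conjuncts into (udf conjuncts, remaining join condition)."""
    conjuncts = split_conjuncts(condition)
    udfs = [c for c in conjuncts if has_udf(c)]
    rest = [c for c in conjuncts if not has_udf(c)]
    return udfs, rest

# udf("a", "b") AND a1 = b1: the udf is pulled out, a1 = b1 stays in the join.
udfs, rest = pull_out_udfs(("and", ("udf", "a", "b"), ("eq", "a1", "b1")))
assert udfs == [("udf", "a", "b")] and rest == [("eq", "a1", "b1")]

# udf-only condition: nothing remains, so the join degenerates to a cross join.
udfs, rest = pull_out_udfs(("udf", "a", "b"))
assert rest == []
```

This mirrors the `splitConjunctivePredicates(condition.get).partition(hasPythonUDF)` call in the rule added by this PR.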
@@ -165,7 +165,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation,
PropagateEmptyRelation) :+
// The following batch should be executed after batch "Join Reorder" and "LocalRelation".
Batch("Extract PythonUDF From JoinCondition", Once,
PullOutPythonUDFInJoinCondition) :+
// The following batch should be executed after the batches "Join Reorder", "LocalRelation" and
// "Extract PythonUDF From JoinCondition".
Batch("Check Cartesian Products", Once,
CheckCartesianProducts) :+
Batch("RewriteSubquery", Once,
@@ -202,7 +205,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
ReplaceDistinctWithAggregate.ruleName ::
PullupCorrelatedPredicates.ruleName ::
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName :: Nil
RewritePredicateSubquery.ruleName ::
PullOutPythonUDFInJoinCondition.ruleName :: Nil

/**
* Optimize all the subqueries inside expression.
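The batch ordering above matters: when the udf is the only conjunct, the pull-out leaves the join with no condition, and CheckCartesianProducts (which runs afterwards) must then see the implicit cross join. Here is a toy Python sketch of that interaction under a deliberately simplified representation of joins and rules; none of these names are Spark APIs.

```python
class AnalysisError(Exception):
    """Stand-in for Spark's AnalysisException."""

def check_cartesian_products(join_condition, cross_join_enabled):
    """Toy CheckCartesianProducts: reject joins left with no condition."""
    if join_condition is None and not cross_join_enabled:
        raise AnalysisError("Detected implicit cartesian product")

def optimize(conjuncts, cross_join_enabled):
    """Run the two batches in order: pull out udfs, then check for cross joins."""
    remaining = [c for c in conjuncts if not c.startswith("udf")]  # pull-out batch
    condition = " AND ".join(remaining) if remaining else None
    check_cartesian_products(condition, cross_join_enabled)        # check batch
    return condition

# A normal conjunct survives, so the check passes.
assert optimize(["udf(a, b)", "a1 = b1"], cross_join_enabled=False) == "a1 = b1"

# A udf-only condition becomes a cross join and is rejected unless the conf is set.
try:
    optimize(["udf(a, b)"], cross_join_enabled=False)
except AnalysisError as e:
    assert "cartesian" in str(e)
optimize(["udf(a, b)"], cross_join_enabled=True)  # allowed when the conf is set
```

Running the check before the pull-out would miss this case, which is why the comment in the diff pins the batch order.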
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.annotation.tailrec

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans._
@@ -152,3 +153,51 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))
}
}

/**
* A PythonUDF in a join condition cannot be evaluated by the join operator, so this rule
* detects such PythonUDFs and pulls them out of the join condition. Python UDFs accessing
* attributes from only one side are normally pushed down by the predicate push-down rules;
* if they are not (e.g. the user disables the filter push-down rules), this rule pulls them
* out as well.
*/
object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
def hasPythonUDF(expression: Expression): Boolean = {
expression.collectFirst { case udf: PythonUDF => udf }.isDefined
}

override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case j @ Join(_, _, joinType, condition)
if condition.isDefined && hasPythonUDF(condition.get) =>
if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
// The current strategy supports only InnerLike and LeftSemi joins, because for other join
// types it breaks SQL semantics to evaluate the join condition as a filter after the join.
// If we let such a plan pass through here, it would still fail with an invalid PythonUDF
// RuntimeException (`requires attributes from more than one child`), so we throw here first
// to give a more readable error message.
throw new AnalysisException("Using PythonUDF in join condition of join type" +
s" $joinType is not supported.")
}
// If the condition contains python udfs, they are pulled out of
// the new join condition.
val (udf, rest) =
splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
val newCondition = if (rest.isEmpty) {
logWarning(s"The join condition $condition of the join plan contains only PythonUDF;" +
s" it will be pulled out and the join plan will be turned into a cross join.")
None
} else {
Some(rest.reduceLeft(And))
}
val newJoin = j.copy(condition = newCondition)
joinType match {
case _: InnerLike => Filter(udf.reduceLeft(And), newJoin)
case LeftSemi =>
Project(
j.left.output.map(_.toAttribute),
Filter(udf.reduceLeft(And), newJoin.copy(joinType = Inner)))
case _ =>
throw new AnalysisException("Using PythonUDF in join condition of join type" +
s" $joinType is not supported.")
}
}
}

Review thread on the LeftSemi branch:

Contributor: So we are simulating a left semi join here. It seems we can do the same thing for a left anti join.

Member (author): Let me try.

Member (author): I tried two ways to implement LeftAnti here:

1. Use Except(join.left, left semi result, isAll = false) to simulate it, but that is banned by the strategy, and there is actually no physical plan for Except:

case logical.Except(left, right, false) =>
throw new IllegalStateException(
"logical except operator should have been replaced by anti-join in the optimizer")

2. Also use a cross join plus a filter to simulate it, but this cannot work when the udf is the only predicate in the anti join condition, because after the cross join it is hard to roll back to the original rows. A udf plus a normal condition could be simulated by:

Project(
j.left.output.map(_.toAttribute),
Filter(Not(udf.reduceLeft(And)),
newJoin.copy(joinType = Inner, condition = Some(Not(rest.reduceLeft(And))))))

Contributor: Ah, let's leave left anti join then, thanks for trying!

Member (author): Got it, thanks :)
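To see why the rewrite preserves query results, here is a pure-Python simulation of the two supported cases, run on the same rows as the tests above with rows modelled as dicts. This illustrates the semantics only; the function names are hypothetical, not Spark code.

```python
def inner_join_then_filter(left, right, rest_pred, udf):
    """Inner join on the non-udf condition, then evaluate the udf as a filter."""
    joined = [{**l, **r} for l in left for r in right if rest_pred(l, r)]
    return [row for row in joined if udf(row)]

def left_semi_via_inner(left, right, rest_pred, udf):
    """LeftSemi rewrite: inner join + udf filter, then project the left columns."""
    matched = inner_join_then_filter(left, right, rest_pred, udf)
    keys = list(left[0].keys())
    # Projecting back to the left side; duplicates would need de-duplication in
    # general, but each left row matches at most once in this example.
    return [{k: row[k] for k in keys} for row in matched]

# Rows from the PR's tests.
left = [{"a": 1, "a1": 1, "a2": 1}, {"a": 2, "a1": 2, "a2": 2}]
right = [{"b": 1, "b1": 1, "b2": 1}, {"b": 1, "b1": 3, "b2": 1}]
udf = lambda row: row["a"] == row["b"]   # the pulled-out python udf
rest = lambda l, r: l["a1"] == r["b1"]   # the remaining join condition

assert inner_join_then_filter(left, right, rest, udf) == [
    {"a": 1, "a1": 1, "a2": 1, "b": 1, "b1": 1, "b2": 1}]
assert left_semi_via_inner(left, right, rest, udf) == [{"a": 1, "a1": 1, "a2": 1}]
```

Both results match the expected rows in `test_udf_and_common_filter_in_join_condition` and its left-semi counterpart.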