Commit 5389013

Davies Liu authored and committed
[SPARK-15888] [SQL] fix Python UDF with aggregate
## What changes were proposed in this pull request?

After we moved the ExtractPythonUDF rule into the physical plan, Python UDFs could no longer be used on top of an aggregate, because they cannot be evaluated before the aggregate; they have to be evaluated after it. This PR adds another rule that extracts this kind of Python UDF from the logical Aggregate and creates a Project on top of the Aggregate.

## How was this patch tested?

Added regression tests. The plan of the added test query looks like this:

```
== Parsed Logical Plan ==
'Project [<lambda>('k, 's) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Analyzed Logical Plan ==
t: int
Project [<lambda>(k#17, s#22L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS k#17, sum(cast(<lambda>(value#6) as bigint)) AS s#22L]
   +- LogicalRDD [key#5L, value#6]

== Optimized Logical Plan ==
Project [<lambda>(agg#29, agg#30L) AS t#26]
+- Aggregate [<lambda>(key#5L)], [<lambda>(key#5L) AS agg#29, sum(cast(<lambda>(value#6) as bigint)) AS agg#30L]
   +- LogicalRDD [key#5L, value#6]

== Physical Plan ==
*Project [pythonUDF0#37 AS t#26]
+- BatchEvalPython [<lambda>(agg#29, agg#30L)], [agg#29, agg#30L, pythonUDF0#37]
   +- *HashAggregate(key=[<lambda>(key#5L)#31], functions=[sum(cast(<lambda>(value#6) as bigint))], output=[agg#29,agg#30L])
      +- Exchange hashpartitioning(<lambda>(key#5L)#31, 200)
         +- *HashAggregate(key=[pythonUDF0#34 AS <lambda>(key#5L)#31], functions=[partial_sum(cast(pythonUDF1#35 as bigint))], output=[<lambda>(key#5L)#31,sum#33L])
            +- BatchEvalPython [<lambda>(key#5L), <lambda>(value#6)], [key#5L, value#6, pythonUDF0#34, pythonUDF1#35]
               +- Scan ExistingRDD[key#5L,value#6]
```

Author: Davies Liu <davies@databricks.com>

Closes #13682 from davies/fix_py_udf.
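For context, a minimal standalone sketch of the query pattern this change enables, adapted from the regression test added below (the local SparkSession setup is illustrative and not part of the commit):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, sum
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("udf-over-aggregate").getOrCreate()

df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

my_copy = udf(lambda x: x, IntegerType())             # evaluated before the aggregate
my_strlen = udf(lambda x: len(x), IntegerType())      # evaluated before the aggregate
my_add = udf(lambda a, b: int(a + b), IntegerType())  # depends on the aggregate output,
                                                      # so it can only run after the aggregate

result = (df.groupBy(my_copy(col("key")).alias("k"))
            .agg(sum(my_strlen(col("value"))).alias("s"))
            .select(my_add(col("k"), col("s")).alias("t")))

result.explain(True)     # physical plan shows BatchEvalPython above HashAggregate, below a Project
print(result.collect())  # [Row(t=4), Row(t=3)], as asserted by the regression test
```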
1 parent 279bd4a commit 5389013

File tree

4 files changed: +77 -11 lines changed

- python/pyspark/sql/tests.py
- sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
- sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
- sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

python/pyspark/sql/tests.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -339,13 +339,21 @@ def test_broadcast_in_udf(self):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import udf, col, sum
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
         self.assertEqual(sel.collect(), [Row(key=1)])
 
+        my_copy = udf(lambda x: x, IntegerType())
+        my_add = udf(lambda a, b: int(a + b), IntegerType())
+        my_strlen = udf(lambda x: len(x), IntegerType())
+        sel = df.groupBy(my_copy(col("key")).alias("k"))\
+            .agg(sum(my_strlen(col("value"))).alias("s"))\
+            .select(my_add(col("k"), col("s")).alias("t"))
+        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
     experimentalMethods: ExperimentalMethods)
   extends Optimizer(catalog, conf) {
 
-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+  override def batches: Seq[Batch] = super.batches :+
+    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
 
   def children: Seq[SparkPlan] = child :: Nil
 
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
   private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
```
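The `producedAttributes` override declares the appended `pythonUDF*` result columns (everything in `output` beyond the child's own output) as attributes produced by this node itself, so they are not reported as missing input attributes now that `BatchEvalPython` can be planned on top of an aggregate.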

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 62 additions & 8 deletions
```diff
@@ -18,12 +18,68 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.SparkPlan
 
+
+/**
+ * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or
+ * grouping key, evaluate them after aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+  /**
+   * Returns whether the expression could only be evaluated within aggregate.
+   */
+  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+    e.isInstanceOf[AggregateExpression] ||
+      agg.groupingExpressions.exists(_.semanticEquals(e))
+  }
+
+  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+    expr.find {
+      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+    }.isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val aggExpr = new ArrayBuffer[NamedExpression]()
+    agg.aggregateExpressions.foreach { expr =>
+      if (hasPythonUdfOverAggregate(expr, agg)) {
+        // Python UDF can only be evaluated after aggregate
+        val newE = expr transformDown {
+          case e: Expression if belongAggregate(e, agg) =>
+            val alias = e match {
+              case a: NamedExpression => a
+              case o => Alias(e, "agg")()
+            }
+            aggExpr += alias
+            alias.toAttribute
+        }
+        projList += newE.asInstanceOf[NamedExpression]
+      } else {
+        aggExpr += expr
+        projList += expr.toAttribute
+      }
+    }
+    // There is no Python UDF over aggregate expression
+    Project(projList, agg.copy(aggregateExpressions = aggExpr))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+      extract(agg)
+  }
+}
+
+
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
   }
 
   /**
-   * Extract all the PythonUDFs from the current operator.
+   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
     val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      // ignore the PythonUDF that come from second/third aggregate, which is not used
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
       // If there aren't any, we are done.
       plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
       // Other cases are disallowed as they are ambiguous or would require a cartesian
      // product.
       udfs.filterNot(attributeMap.contains).foreach { udf =>
-        if (udf.references.subsetOf(plan.inputSet)) {
-          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-        } else {
-          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-        }
+        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }
 
       val rewritten = plan.transformExpressions {
```
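Two points worth noting in this file: the new `ExtractPythonUDFFromAggregate` rule rewrites the logical `Aggregate` so that any Python UDF referencing an aggregate expression or a grouping key is re-evaluated in a `Project` placed above the aggregate, aliasing unnamed aggregate outputs as `agg` (which is where the `agg#29`/`agg#30L` attributes in the optimized plan shown in the commit message come from); and the physical-plan rule now skips UDFs whose referenced attributes are not all available in `plan.inputSet` (per the in-code comment, those coming from the second/third aggregate), instead of failing with a "Missing input attributes" error.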
