
Commit 31b42eb

Authored and committed by Davies Liu
fix Python UDF with aggregate
1 parent: dae4d5d
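For context, this commit makes Python UDFs work together with groupBy/agg. The snippet below is a minimal PySpark sketch of the pattern exercised by the new test (it assumes an existing SparkSession bound to `spark`); it applies Python UDFs to the grouping key, inside the aggregate function, and on top of the aggregate result:

from pyspark.sql.functions import col, sum, udf
from pyspark.sql.types import IntegerType

df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

my_copy = udf(lambda x: x, IntegerType())             # UDF over the grouping key
my_strlen = udf(lambda x: len(x), IntegerType())      # UDF inside the aggregate function
my_add = udf(lambda a, b: int(a + b), IntegerType())  # UDF over the aggregate output

result = (df.groupBy(my_copy(col("key")).alias("k"))
            .agg(sum(my_strlen(col("value"))).alias("s"))
            .select(my_add(col("k"), col("s")).alias("t"))
            .collect())
# The new test asserts this returns [Row(t=4), Row(t=3)]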

File tree: 4 files changed, +77 -11 lines

  python/pyspark/sql/tests.py
  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
  sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
  sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

python/pyspark/sql/tests.py

Lines changed: 9 additions & 1 deletion
@@ -339,13 +339,21 @@ def test_broadcast_in_udf(self):
 
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import udf, col, sum
         from pyspark.sql.types import BooleanType
 
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
         self.assertEqual(sel.collect(), [Row(key=1)])
 
+        my_copy = udf(lambda x: x, IntegerType())
+        my_add = udf(lambda a, b: int(a + b), IntegerType())
+        my_strlen = udf(lambda x: len(x), IntegerType())
+        sel = df.groupBy(my_copy(col("key")).alias("k"))\
+            .agg(sum(my_strlen(col("value"))).alias("s"))\
+            .select(my_add(col("k"), col("s")).alias("t"))
+        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala

Lines changed: 4 additions & 2 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
 import org.apache.spark.sql.internal.SQLConf
 
 class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
     experimentalMethods: ExperimentalMethods)
   extends Optimizer(catalog, conf) {
 
-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+  override def batches: Seq[Batch] = super.batches :+
+    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
 
   def children: Seq[SparkPlan] = child :: Nil
 
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
   private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 62 additions & 8 deletions
@@ -18,12 +18,68 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.SparkPlan
 
+
+/**
+ * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or
+ * grouping key, evaluate them after aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+  /**
+   * Returns whether the expression could only be evaluated within aggregate.
+   */
+  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+    e.isInstanceOf[AggregateExpression] ||
+      agg.groupingExpressions.exists(_.semanticEquals(e))
+  }
+
+  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+    expr.find {
+      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+    }.isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val aggExpr = new ArrayBuffer[NamedExpression]()
+    agg.aggregateExpressions.foreach { expr =>
+      if (hasPythonUdfOverAggregate(expr, agg)) {
+        // Python UDF can only be evaluated after aggregate
+        val newE = expr transformDown {
+          case e: Expression if belongAggregate(e, agg) =>
+            val alias = e match {
+              case a: NamedExpression => a
+              case o => Alias(e, "agg")()
+            }
+            aggExpr += alias
+            alias.toAttribute
+        }
+        projList += newE.asInstanceOf[NamedExpression]
+      } else {
+        aggExpr += expr
+        projList += expr.toAttribute
+      }
+    }
+    // There is no Python UDF over aggregate expression
+    Project(projList, agg.copy(aggregateExpressions = aggExpr))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+      extract(agg)
+  }
+}
+
+
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
   }
 
   /**
-   * Extract all the PythonUDFs from the current operator.
+   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
     val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      // ignore the PythonUDF that come from second/third aggregate, which is not used
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
       // If there aren't any, we are done.
       plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
       // Other cases are disallowed as they are ambiguous or would require a cartesian
       // product.
       udfs.filterNot(attributeMap.contains).foreach { udf =>
-        if (udf.references.subsetOf(plan.inputSet)) {
-          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-        } else {
-          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-        }
+        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }
 
       val rewritten = plan.transformExpressions {
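As a rough illustration of what the new rule does to the logical plan, the toy Python sketch below (plain Python, not Catalyst or PySpark code; all names are made up for the example) splits an aggregate's output expressions the same way ExtractPythonUDFFromAggregate does: grouping and aggregate sub-expressions stay in the Aggregate under aliases, while any Python UDF applied on top of them is moved into a Project evaluated after the Aggregate.

from dataclasses import dataclass, field
from typing import List

@dataclass
class Expr:
    name: str
    children: List["Expr"] = field(default_factory=list)
    is_python_udf: bool = False
    is_agg_or_grouping: bool = False   # plays the role of belongAggregate()

def contains_agg(e: Expr) -> bool:
    return e.is_agg_or_grouping or any(contains_agg(c) for c in e.children)

def has_udf_over_agg(e: Expr) -> bool:
    # mirrors hasPythonUdfOverAggregate: a Python UDF somewhere above an aggregate/grouping expr
    if e.is_python_udf and contains_agg(e):
        return True
    return any(has_udf_over_agg(c) for c in e.children)

def push_down(e: Expr, agg_output: List[Expr]) -> Expr:
    # mirrors the transformDown in extract(): keep aggregate/grouping parts below,
    # reference them from the expression that will live in the Project above
    if e.is_agg_or_grouping:
        agg_output.append(e)
        return Expr(f"ref({e.name})")
    return Expr(e.name, [push_down(c, agg_output) for c in e.children], e.is_python_udf)

def split(agg_exprs: List[Expr]):
    agg_output, proj_output = [], []
    for expr in agg_exprs:
        if has_udf_over_agg(expr):
            proj_output.append(push_down(expr, agg_output))
        else:
            agg_output.append(expr)
            proj_output.append(Expr(f"ref({expr.name})"))
    return proj_output, agg_output   # conceptually Project(proj_output, Aggregate(..., agg_output))

# Shape of the new test: my_add(my_copy(key) AS k, sum(my_strlen(value)) AS s)
k = Expr("my_copy(key) AS k", is_agg_or_grouping=True)
s = Expr("sum(my_strlen(value)) AS s", is_agg_or_grouping=True)
proj, agg = split([Expr("my_add", [k, s], is_python_udf=True)])
print([e.name for e in agg])    # stays in the Aggregate: grouping key and sum(...)
print([e.name for e in proj])   # evaluated afterwards in a Project: my_add over ref(k), ref(s)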
