package org.apache.spark.sql.execution.python

import scala.collection.mutable
+ import scala.collection.mutable.ArrayBuffer

- import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+ import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.SparkPlan
+
+ /**
+  * Extracts all the Python UDFs in a logical Aggregate that depend on an aggregate expression
+  * or a grouping key, so that they can be evaluated after the aggregate.
+  */
+ private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+   /**
+    * Returns whether the expression can only be evaluated within the aggregate.
+    */
+   private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+     e.isInstanceOf[AggregateExpression] ||
+       agg.groupingExpressions.exists(_.semanticEquals(e))
+   }
+
+   private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+     expr.find {
+       e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+     }.isDefined
+   }
+
+   private def extract(agg: Aggregate): LogicalPlan = {
+     val projList = new ArrayBuffer[NamedExpression]()
+     val aggExpr = new ArrayBuffer[NamedExpression]()
+     agg.aggregateExpressions.foreach { expr =>
+       if (hasPythonUdfOverAggregate(expr, agg)) {
+         // A Python UDF over aggregate expressions or grouping keys can only be
+         // evaluated after the aggregate
+         val newE = expr transformDown {
+           case e: Expression if belongAggregate(e, agg) =>
+             val alias = e match {
+               case a: NamedExpression => a
+               case o => Alias(e, "agg")()
+             }
+             aggExpr += alias
+             alias.toAttribute
+         }
+         projList += newE.asInstanceOf[NamedExpression]
+       } else {
+         aggExpr += expr
+         projList += expr.toAttribute
+       }
+     }
+     // The rewritten Aggregate no longer contains any Python UDF over an aggregate
+     // expression; the UDFs are evaluated in the Project above it
+     Project(projList, agg.copy(aggregateExpressions = aggExpr))
+   }
+
+   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+     case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+       extract(agg)
+   }
+ }
+
+
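As a sketch of what the new rule does (illustrative only, not part of the patch): suppose a Python UDF `pyUDF` registered from PySpark is applied on top of an aggregate result, e.g. `SELECT pyUDF(max(value)) FROM t GROUP BY key`. The expression IDs below are made up:

    Before ExtractPythonUDFFromAggregate:
      Aggregate [key], [pyUDF(max(value)) AS c0]
      +- Relation t[key, value]

    After:
      Project [pyUDF(agg#1) AS c0]
      +- Aggregate [key], [max(value) AS agg#1]
         +- Relation t[key, value]

The Aggregate itself is then free of Python UDFs and can be planned as usual; the UDF runs afterwards over the aggregate's output attribute (named by the `Alias(e, "agg")()` above).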
/**
 * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
 * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
  }

  /**
-  * Extract all the PythonUDFs from the current operator.
+  * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
   */
- def extract(plan: SparkPlan): SparkPlan = {
+ private def extract(plan: SparkPlan): SparkPlan = {
    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+     // Ignore the PythonUDFs that come from a second/third aggregate, which are not used
+     .filter(udf => udf.references.subsetOf(plan.inputSet))
    if (udfs.isEmpty) {
      // If there aren't any, we are done.
      plan
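A minimal sketch of what the new `.filter` checks, exercising Catalyst's `AttributeSet` directly (the attributes `a` and `b` are made up for illustration):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val inputSet = AttributeSet(a :: Nil)  // stands in for plan.inputSet

    AttributeSet(a :: Nil).subsetOf(inputSet)       // true: a UDF over `a` is kept
    AttributeSet(a :: b :: Nil).subsetOf(inputSet)  // false: a UDF also needing `b` is skipped

A UDF whose references are not all produced by this operator's children belongs to another aggregate stage introduced during planning, so it is now skipped here instead of raising an error.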
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
      // Other cases are disallowed as they are ambiguous or would require a cartesian
      // product.
      udfs.filterNot(attributeMap.contains).foreach { udf =>
-       if (udf.references.subsetOf(plan.inputSet)) {
-         sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-       } else {
-         sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-       }
+       sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
      }

      val rewritten = plan.transformExpressions {
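With unevaluable UDFs already filtered out at the top of `extract`, every collected UDF satisfies `udf.references.subsetOf(plan.inputSet)`, so the old `else` branch ("Missing input attributes") can never fire and a single error message suffices. An illustrative (made-up) trigger for the remaining error would be a Python UDF in a join condition that reads columns from both sides:

    SELECT * FROM l JOIN r ON pyUDF(l.a, r.b)  -- requires attributes from more than one child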