
Commit 8007fa6

common trait for grouped pandas udfs
1 parent e3b66ac commit 8007fa6

4 files changed: +156 -110 lines changed


sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

Lines changed: 9 additions & 5 deletions
@@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
  */
 @Stable
 class RelationalGroupedDataset protected[sql](
-    private val df: DataFrame,
-    private val groupingExprs: Seq[Expression],
+    val df: DataFrame,
+    val groupingExprs: Seq[Expression],
     groupType: RelationalGroupedDataset.GroupType) {
 
   private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
@@ -542,11 +542,15 @@ class RelationalGroupedDataset protected[sql](
 
     val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute)
     val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute)
-    val left = df.logicalPlan
-    val right = r.df.logicalPlan
+
+    val leftChild = df.logicalPlan
+    val rightChild = r.df.logicalPlan
+
+    val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)
+    val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)
+
     val output = expr.dataType.asInstanceOf[StructType].toAttributes
     val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right)
-
     Dataset.ofRows(df.sparkSession, plan)
   }
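Side note, not part of the commit: the Project added above prepends the grouping expressions to each child's output, so the physical operator later sees the grouping columns as a prefix of child.output and can peel them off again (see createSchema in the new trait below). A minimal plain-Scala sketch of that bookkeeping, using hypothetical column names in place of Spark expressions:

    // Toy model only: Seq[String] stands in for the attribute lists.
    object ProjectPrefixSketch extends App {
      val grouping    = Seq("id")       // leftGroupingNamedExpressions (hypothetical)
      val childOutput = Seq("id", "v")  // leftChild.output (hypothetical)

      // Project(grouping ++ childOutput, child): grouping columns become a prefix.
      val projected = grouping ++ childOutput               // Seq(id, id, v)

      // Downstream, the exec node recovers the data columns by dropping that prefix.
      val dataAttributes = projected.drop(grouping.length)  // Seq(id, v)

      println(projected.mkString(", "))       // id, id, v
      println(dataAttributes.mkString(", "))  // id, v
    }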

sql/core/src/main/scala/org/apache/spark/sql/execution/python/AbstractPandasGroupExec.scala

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF, UnsafeProjection}
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConverters._
+
+trait AbstractPandasGroupExec extends SparkPlan {
+
+  protected val sessionLocalTimeZone = conf.sessionLocalTimeZone
+
+  protected val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+
+  protected def chainedFunc = Seq(
+    ChainedPythonFunctions(Seq(func.asInstanceOf[PythonUDF].func)))
+
+  def output: Seq[Attribute]
+
+  def func: Expression
+
+  protected def executePython[T](data: Iterator[T],
+      runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = {
+
+    val context = TaskContext.get()
+    val columnarBatchIter = runner.compute(data, context.partitionId(), context)
+    val unsafeProj = UnsafeProjection.create(output, output)
+
+    columnarBatchIter.flatMap { batch =>
+      // UDF returns a StructType column in ColumnarBatch, select the children here
+      val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
+      val outputVectors = output.indices.map(structVector.getChild)
+      val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
+      flattenedBatch.setNumRows(batch.numRows())
+      flattenedBatch.rowIterator.asScala
+    }.map(unsafeProj)
+
+  }
+
+  protected def groupAndDedup(
+      input: Iterator[InternalRow], groupingAttributes: Seq[Attribute],
+      inputSchema: Seq[Attribute], dedupSchema: Seq[Attribute]): Iterator[Iterator[InternalRow]] = {
+    if (groupingAttributes.isEmpty) {
+      Iterator(input)
+    } else {
+      val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema)
+      val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema)
+      groupedIter.map {
+        case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
+      }
+    }
+  }
+
+  protected def createSchema(child: SparkPlan, groupingAttributes: Seq[Attribute])
+    : (StructType, Seq[Attribute], Array[Array[Int]]) = {
+
+    // Deduplicate the grouping attributes.
+    // If a grouping attribute also appears in data attributes, then we don't need to send the
+    // grouping attribute to Python worker. If a grouping attribute is not in data attributes,
+    // then we need to send this grouping attribute to python worker.
+    //
+    // We use argOffsets to distinguish grouping attributes and data attributes as following:
+    //
+    // argOffsets[0] is the length of grouping attributes
+    // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
+    // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
+
+    val dataAttributes = child.output.drop(groupingAttributes.length)
+    val groupingIndicesInData = groupingAttributes.map { attribute =>
+      dataAttributes.indexWhere(attribute.semanticEquals)
+    }
+
+    val groupingArgOffsets = new ArrayBuffer[Int]
+    val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
+    val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
+
+    // Non duplicate grouping attributes are added to nonDupGroupingAttributes and
+    // their offsets are 0, 1, 2 ...
+    // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
+    // their offsets are n + index, where n is the total number of non duplicate grouping
+    // attributes and index is the index in the data attributes that the grouping attribute
+    // is a duplicate of.
+
+    groupingAttributes.zip(groupingIndicesInData).foreach {
+      case (attribute, index) =>
+        if (index == -1) {
+          groupingArgOffsets += nonDupGroupingAttributes.length
+          nonDupGroupingAttributes += attribute
+        } else {
+          groupingArgOffsets += index + nonDupGroupingSize
+        }
+    }
+
+    val dataArgOffsets = nonDupGroupingAttributes.length until
+      (nonDupGroupingAttributes.length + dataAttributes.length)
+
+    val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
+
+    // Attributes after deduplication
+    val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
+    val dedupSchema = StructType.fromAttributes(dedupAttributes)
+    (dedupSchema, dedupAttributes, argOffsets)
+  }
+
+}
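To make the argOffsets layout in createSchema concrete, here is a self-contained sketch that is not part of the commit: plain strings stand in for Attributes, semanticEquals is approximated by name equality, and the column names are purely illustrative. It models the inner array of the Array(Array(...)) the real method returns.

    import scala.collection.mutable.ArrayBuffer

    // Grouping columns that also appear in the data are not sent twice to the
    // Python worker; they are addressed by an offset into the deduplicated schema.
    object ArgOffsetsSketch extends App {
      val grouping = Seq("id", "bucket")  // grouping attributes (hypothetical)
      val data     = Seq("id", "v")       // data attributes after dropping the grouping prefix

      val indicesInData = grouping.map(data.indexOf)    // Seq(0, -1): "id" duplicates data(0)
      val nonDupSize    = indicesInData.count(_ == -1)  // one grouping column is not in the data

      val groupingOffsets = ArrayBuffer[Int]()
      val nonDupGrouping  = ArrayBuffer[String]()
      grouping.zip(indicesInData).foreach {
        case (attr, -1) =>
          groupingOffsets += nonDupGrouping.length
          nonDupGrouping += attr
        case (_, index) =>
          groupingOffsets += index + nonDupSize  // points into the data section
      }

      val dataOffsets = nonDupGrouping.length until (nonDupGrouping.length + data.length)
      val argOffsets  = Array(grouping.length) ++ groupingOffsets ++ dataOffsets
      val dedupSchema = nonDupGrouping ++ data

      println(dedupSchema.mkString(", "))  // bucket, id, v
      println(argOffsets.mkString(", "))   // 2, 1, 0, 1, 2
    }

So a duplicated grouping column ("id" here) is read straight out of the data section (offset 1), and only the non-duplicated one ("bucket") is shipped as an extra leading column.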

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala

Lines changed: 13 additions & 31 deletions
@@ -17,17 +17,12 @@
 
 package org.apache.spark.sql.execution.python
 
-import scala.collection.JavaConverters._
-
-import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, GroupedIterator, SparkPlan}
-import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 case class FlatMapCoGroupsInPandasExec(
     leftGroup: Seq[Attribute],
@@ -36,9 +31,7 @@ case class FlatMapCoGroupsInPandasExec(
     output: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan)
-  extends BinaryExecNode {
-
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
+  extends BinaryExecNode with AbstractPandasGroupExec {
 
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
@@ -53,41 +46,30 @@ case class FlatMapCoGroupsInPandasExec(
       .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
   }
 
-
   override protected def doExecute(): RDD[InternalRow] = {
 
-    val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-    val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
-
+    val (schemaLeft, attrLeft, _) = createSchema(left, leftGroup)
+    val (schemaRight, attrRight, _) = createSchema(right, rightGroup)
 
     left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
       val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
       val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
-      val cogroup = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
-        .map{case (k, l, r) => (l, r)}
-      val context = TaskContext.get()
+      val projLeft = UnsafeProjection.create(attrLeft, left.output)
+      val projRight = UnsafeProjection.create(attrRight, right.output)
+      val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
+        .map{case (k, l, r) => (l.map(projLeft), r.map(projRight))}
 
-      val columnarBatchIter = new InterleavedArrowPythonRunner(
+      val runner = new InterleavedArrowPythonRunner(
        chainedFunc,
        PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
        Array(Array.empty),
-       left.schema,
-       right.schema,
+       schemaLeft,
+       schemaRight,
       sessionLocalTimeZone,
-      pythonRunnerConf).compute(cogroup, context.partitionId(), context)
-
+      pythonRunnerConf)
 
-      val unsafeProj = UnsafeProjection.create(output, output)
+      executePython(data, runner)
 
-      columnarBatchIter.flatMap { batch =>
-        // UDF returns a StructType column in ColumnarBatch, select the children here
-        val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
-        val outputVectors = output.indices.map(structVector.getChild)
-        val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
-        flattenedBatch.setNumRows(batch.numRows())
-        flattenedBatch.rowIterator.asScala
-      }.map(unsafeProj)
     }
 
   }
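As a rough mental model of what this node now feeds the runner (again, not from the commit): the two inputs are grouped by key on each side, CoGroupedIterator pairs up the per-key groups across the union of keys, the key is dropped, and each (left rows, right rows) pair goes to Python. A toy, Spark-free sketch with hypothetical (key, value) pairs:

    // Pair per-key groups from two key-sorted inputs, full-outer style.
    object CoGroupSketch extends App {
      val left  = Seq(1 -> "a", 1 -> "b", 2 -> "c")  // hypothetical (key, value) rows
      val right = Seq(1 -> "x", 3 -> "y")

      val keys = (left.map(_._1) ++ right.map(_._1)).distinct.sorted
      val cogrouped = keys.map { k =>
        (k, left.collect { case (`k`, v) => v }, right.collect { case (`k`, v) => v })
      }

      // The exec node then drops the key, as in .map{case (k, l, r) => ...} above.
      cogrouped.map { case (_, l, r) => (l, r) }.foreach(println)
      // (List(a, b),List(x))
      // (List(c),List())
      // (List(),List(y))
    }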

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala

Lines changed: 6 additions & 74 deletions
@@ -53,9 +53,7 @@ case class FlatMapGroupsInPandasExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends UnaryExecNode {
-
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
+  extends UnaryExecNode with AbstractPandasGroupExec {
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
@@ -75,88 +73,22 @@
   override protected def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute()
 
-    val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-    val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
-
-    // Deduplicate the grouping attributes.
-    // If a grouping attribute also appears in data attributes, then we don't need to send the
-    // grouping attribute to Python worker. If a grouping attribute is not in data attributes,
-    // then we need to send this grouping attribute to python worker.
-    //
-    // We use argOffsets to distinguish grouping attributes and data attributes as following:
-    //
-    // argOffsets[0] is the length of grouping attributes
-    // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
-    // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
-
-    val dataAttributes = child.output.drop(groupingAttributes.length)
-    val groupingIndicesInData = groupingAttributes.map { attribute =>
-      dataAttributes.indexWhere(attribute.semanticEquals)
-    }
-
-    val groupingArgOffsets = new ArrayBuffer[Int]
-    val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
-    val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
-
-    // Non duplicate grouping attributes are added to nonDupGroupingAttributes and
-    // their offsets are 0, 1, 2 ...
-    // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
-    // their offsets are n + index, where n is the total number of non duplicate grouping
-    // attributes and index is the index in the data attributes that the grouping attribute
-    // is a duplicate of.
-
-    groupingAttributes.zip(groupingIndicesInData).foreach {
-      case (attribute, index) =>
-        if (index == -1) {
-          groupingArgOffsets += nonDupGroupingAttributes.length
-          nonDupGroupingAttributes += attribute
-        } else {
-          groupingArgOffsets += index + nonDupGroupingSize
-        }
-    }
-
-    val dataArgOffsets = nonDupGroupingAttributes.length until
-      (nonDupGroupingAttributes.length + dataAttributes.length)
-
-    val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
-
-    // Attributes after deduplication
-    val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
-    val dedupSchema = StructType.fromAttributes(dedupAttributes)
+    val (dedupSchema, dedupAttributes, argOffsets) = createSchema(child, groupingAttributes)
 
     // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
     inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
-      val grouped = if (groupingAttributes.isEmpty) {
-        Iterator(iter)
-      } else {
-        val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
-        val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
-        groupedIter.map {
-          case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
-        }
-      }
 
-      val context = TaskContext.get()
+      val data = groupAndDedup(iter, groupingAttributes, child.output, dedupAttributes)
 
-      val columnarBatchIter = new ArrowPythonRunner(
+      val runner = new ArrowPythonRunner(
        chainedFunc,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
        argOffsets,
        dedupSchema,
        sessionLocalTimeZone,
-       pythonRunnerConf).compute(grouped, context.partitionId(), context)
-
-      val unsafeProj = UnsafeProjection.create(output, output)
+       pythonRunnerConf)
 
-      columnarBatchIter.flatMap { batch =>
-        // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
-        val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
-        val outputVectors = output.indices.map(structVector.getChild)
-        val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
-        flattenedBatch.setNumRows(batch.numRows())
-        flattenedBatch.rowIterator.asScala
-      }.map(unsafeProj)
+      executePython(data, runner)
     }}
   }
 }
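Stepping back, the refactoring follows the usual Scala mixin pattern: the trait declares func and output as abstract members, the case classes satisfy them with their existing constructor parameters, and the shared plumbing (chainedFunc, sessionLocalTimeZone, pythonRunnerConf, executePython, createSchema, groupAndDedup) is written once against those members. A stripped-down illustration of the pattern with hypothetical names, independent of Spark:

    // Shared helpers live in the trait and are defined in terms of abstract members.
    trait GroupExecLike {
      def label: String                      // plays the role of func/output
      def banner: String = s"[exec] $label"  // plays the role of chainedFunc etc.
    }

    final case class MapNode(label: String) extends GroupExecLike
    final case class CoMapNode(label: String, other: String) extends GroupExecLike

    object MixinSketch extends App {
      println(MapNode("grouped map").banner)  // [exec] grouped map
      println(CoMapNode("cogrouped map", "right").banner)
    }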
