Commit d4cf6d0

poc using arrow streams
1 parent 8007fa6 commit d4cf6d0

6 files changed: +62, -164 lines changed

python/pyspark/serializers.py

Lines changed: 16 additions & 11 deletions
@@ -359,19 +359,24 @@ def __repr__(self):
 class InterleavedArrowReader(object):
 
     def __init__(self, stream):
-        import pyarrow as pa
-        self._schema1 = pa.read_schema(stream)
-        self._schema2 = pa.read_schema(stream)
-        self._reader = pa.MessageReader.open_stream(stream)
+        self._stream = stream
 
     def __iter__(self):
         return self
 
     def __next__(self):
+        stream_status = read_int(self._stream)
+        if stream_status == SpecialLengths.START_ARROW_STREAM:
+            return self._read_df(), self._read_df()
+        elif stream_status == SpecialLengths.END_OF_DATA_SECTION:
+            raise StopIteration
+        else:
+            raise ValueError('Received invalid stream status {0}'.format(stream_status))
+
+    def _read_df(self):
         import pyarrow as pa
-        batch1 = pa.read_record_batch(self._reader.read_next_message(), self._schema1)
-        batch2 = pa.read_record_batch(self._reader.read_next_message(), self._schema2)
-        return batch1, batch2
+        reader = pa.ipc.open_stream(self._stream)
+        return [b for b in reader]
 
 
 class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):

@@ -428,11 +433,11 @@ def load_stream(self, stream):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
-        import pyarrow as pa
-        reader = InterleavedArrowReader(pa.input_stream(stream))
+        reader = InterleavedArrowReader(stream)
         for batch1, batch2 in reader:
-            yield ( [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch1]).itercolumns()],
-                    [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch2]).itercolumns()])
+            import pyarrow as pa
+            yield ([self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
+                   [self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()])
 
 
 class BatchedSerializer(Serializer):
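
The hunks above replace the message-level reader (two schemas read up front plus pa.MessageReader) with a simpler framing: for every co-group the JVM writes a status int, then two complete Arrow IPC streams (left group, then right group), and finally an end-of-data marker. The following is a minimal, self-contained sketch of that framing and of how load_stream consumes it; the numeric marker values and the helper names write_int/read_int/write_group are assumptions for illustration only, not code from this patch.

import io
import struct

import pyarrow as pa

START_ARROW_STREAM = -6    # assumed value standing in for SpecialLengths.START_ARROW_STREAM
END_OF_DATA_SECTION = -1   # assumed value standing in for SpecialLengths.END_OF_DATA_SECTION

def write_int(value, stream):
    # PySpark frames control values as 4-byte big-endian signed ints.
    stream.write(struct.pack("!i", value))

def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]

def write_group(batch, stream):
    # Each group is a complete Arrow IPC stream: schema, batches, end-of-stream marker.
    writer = pa.ipc.new_stream(stream, batch.schema)
    writer.write_batch(batch)
    writer.close()

# Writer side: one co-group (left, then right), followed by the end marker.
buf = io.BytesIO()
left = pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["id"])
right = pa.RecordBatch.from_arrays([pa.array(["a", "b"])], names=["v"])
write_int(START_ARROW_STREAM, buf)
write_group(left, buf)
write_group(right, buf)
write_int(END_OF_DATA_SECTION, buf)
buf.seek(0)

# Reader side, mirroring __next__/_read_df and load_stream: one batch list per
# side of the co-group, each converted column-wise to pandas.Series.
while True:
    status = read_int(buf)
    if status == END_OF_DATA_SECTION:
        break
    if status != START_ARROW_STREAM:
        raise ValueError('Received invalid stream status {0}'.format(status))
    left_batches = list(pa.ipc.open_stream(buf))
    right_batches = list(pa.ipc.open_stream(buf))
    left_series = [c.to_pandas() for c in pa.Table.from_batches(left_batches).itercolumns()]
    right_series = [c.to_pandas() for c in pa.Table.from_batches(right_batches).itercolumns()]
    print(left_series, right_series)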
Lines changed: 13 additions & 16 deletions
@@ -20,7 +20,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, UnsafeProjection}
 import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils

@@ -29,18 +29,19 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.JavaConverters._
 
-trait AbstractPandasGroupExec extends SparkPlan {
+abstract class BasePandasGroupExec(func: Expression,
+                                   output: Seq[Attribute]) extends SparkPlan {
 
   protected val sessionLocalTimeZone = conf.sessionLocalTimeZone
 
   protected val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
 
-  protected def chainedFunc = Seq(
-    ChainedPythonFunctions(Seq(func.asInstanceOf[PythonUDF].func)))
+  protected val pandasFunction = func.asInstanceOf[PythonUDF].func
 
-  def output: Seq[Attribute]
+  protected val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+
+  override def producedAttributes: AttributeSet = AttributeSet(output)
 
-  def func: Expression
 
   protected def executePython[T](data: Iterator[T],
       runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = {

@@ -62,16 +63,12 @@ trait AbstractPandasGroupExec extends SparkPlan {
 
   protected def groupAndDedup(
       input: Iterator[InternalRow], groupingAttributes: Seq[Attribute],
-      inputSchema: Seq[Attribute], dedupSchema: Seq[Attribute]): Iterator[Iterator[InternalRow]] = {
-    if (groupingAttributes.isEmpty) {
-      Iterator(input)
-    } else {
-      val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema)
-      val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema)
-      groupedIter.map {
-        case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
-      }
-    }
+      inputSchema: Seq[Attribute], dedupSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
+    val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema)
+    val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema)
+    groupedIter.map {
+      case (k, groupedRowIter) => (k, groupedRowIter.map(dedupProj))
+    }
   }
 
   protected def createSchema(child: SparkPlan, groupingAttributes: Seq[Attribute])
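
With this refactor, groupAndDedup always runs GroupedIterator (the old empty-grouping short-circuit is dropped) and returns (group key, deduplicated rows) pairs, so FlatMapCoGroupsInPandasExec can align left and right groups on the key while FlatMapGroupsInPandasExec simply discards it. A rough Python analogue of that shape, with purely illustrative names and not code from the patch:

from itertools import groupby

def group_and_dedup(rows, key_fn, dedup_fn):
    # rows are assumed to arrive clustered by key, which the real operator
    # gets from Spark's required child distribution and ordering.
    for key, group in groupby(rows, key=key_fn):
        # dedup_fn plays the role of the dedup projection applied to each row.
        yield key, (dedup_fn(r) for r in group)

rows = [("a", 1), ("a", 2), ("b", 3)]
for key, group in group_and_dedup(rows, key_fn=lambda r: r[0], dedup_fn=lambda r: r[1]):
    # FlatMapGroupsInPandasExec drops the key (.map{case(_, x) => x});
    # FlatMapCoGroupsInPandasExec keeps it to co-group the left and right sides.
    print(key, list(group))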

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala

Lines changed: 7 additions & 10 deletions
@@ -31,12 +31,10 @@ case class FlatMapCoGroupsInPandasExec(
     output: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan)
-  extends BinaryExecNode with AbstractPandasGroupExec {
+  extends BasePandasGroupExec(func, output) with BinaryExecNode{
 
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
-  override def producedAttributes: AttributeSet = AttributeSet(output)
-
   override def requiredChildDistribution: Seq[Distribution] = {
     ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
   }

@@ -48,16 +46,15 @@ case class FlatMapCoGroupsInPandasExec(
 
   override protected def doExecute(): RDD[InternalRow] = {
 
-    val (schemaLeft, attrLeft, _) = createSchema(left, leftGroup)
-    val (schemaRight, attrRight, _) = createSchema(right, rightGroup)
+    val (schemaLeft, leftDedup, _) = createSchema(left, leftGroup)
+    val (schemaRight, rightDedup, _) = createSchema(right, rightGroup)
 
     left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
-      val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
-      val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
-      val projLeft = UnsafeProjection.create(attrLeft, left.output)
-      val projRight = UnsafeProjection.create(attrRight, right.output)
+
+      val leftGrouped = groupAndDedup(leftData, leftGroup, left.output, leftDedup)
+      val rightGrouped = groupAndDedup(rightData, rightGroup, right.output, rightDedup)
       val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
-        .map{case (k, l, r) => (l.map(projLeft), r.map(projRight))}
+        .map{case (k, l, r) => (l, r)}
 
       val runner = new InterleavedArrowPythonRunner(
         chainedFunc,

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala

Lines changed: 2 additions & 3 deletions
@@ -53,12 +53,10 @@ case class FlatMapGroupsInPandasExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends UnaryExecNode with AbstractPandasGroupExec {
+  extends BasePandasGroupExec(func, output) with UnaryExecNode {
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
-  override def producedAttributes: AttributeSet = AttributeSet(output)
-
   override def requiredChildDistribution: Seq[Distribution] = {
     if (groupingAttributes.isEmpty) {
       AllTuples :: Nil

@@ -79,6 +77,7 @@ case class FlatMapGroupsInPandasExec(
     inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
 
       val data = groupAndDedup(iter, groupingAttributes, child.output, dedupAttributes)
+        .map{case(_, x) => x}
 
       val runner = new ArrowPythonRunner(
         chainedFunc,

sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala

Lines changed: 24 additions & 39 deletions
@@ -21,7 +21,8 @@ import java.io._
 import java.net._
 
 import org.apache.arrow.vector.VectorSchemaRoot
-
+import org.apache.arrow.vector.dictionary.DictionaryProvider
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.spark._
 import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow

@@ -64,55 +65,39 @@ class InterleavedArrowPythonRunner(
       }
 
       protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
-        val leftArrowSchema = ArrowUtils.toArrowSchema(leftSchema, timeZoneId)
-        val rightArrowSchema = ArrowUtils.toArrowSchema(rightSchema, timeZoneId)
+        while (inputIterator.hasNext) {
+          dataOut.writeInt(SpecialLengths.START_ARROW_STREAM)
+          val (nextLeft, nextRight) = inputIterator.next()
+          writeGroup(nextLeft, leftSchema, dataOut)
+          writeGroup(nextRight, rightSchema, dataOut)
+        }
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+      }
+
+      def writeGroup(group: Iterator[InternalRow], schema: StructType, dataOut: DataOutputStream
+        ) = {
+        val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
         val allocator = ArrowUtils.rootAllocator.newChildAllocator(
           s"stdout writer for $pythonExec", 0, Long.MaxValue)
-        val leftRoot = VectorSchemaRoot.create(leftArrowSchema, allocator)
-        val rightRoot = VectorSchemaRoot.create(rightArrowSchema, allocator)
+        val root = VectorSchemaRoot.create(arrowSchema, allocator)
 
         Utils.tryWithSafeFinally {
-          val leftArrowWriter = ArrowWriter.create(leftRoot)
-          val rightArrowWriter = ArrowWriter.create(rightRoot)
-          val writer = InterleavedArrowWriter(leftRoot, rightRoot, dataOut)
+          val writer = new ArrowStreamWriter(root, null, dataOut)
+          val arrowWriter = ArrowWriter.create(root)
          writer.start()
 
-          while (inputIterator.hasNext) {
-
-            val (nextLeft, nextRight) = inputIterator.next()
-
-            while (nextLeft.hasNext) {
-              leftArrowWriter.write(nextLeft.next())
-            }
-            while (nextRight.hasNext) {
-              rightArrowWriter.write(nextRight.next())
-            }
-            leftArrowWriter.finish()
-            rightArrowWriter.finish()
-            writer.writeBatch()
-            leftArrowWriter.reset()
-            rightArrowWriter.reset()
+          while (group.hasNext) {
+            arrowWriter.write(group.next())
           }
-          // end writes footer to the output stream and doesn't clean any resources.
-          // It could throw exception if the output stream is closed, so it should be
-          // in the try block.
+          arrowWriter.finish()
+          writer.writeBatch()
          writer.end()
-        } {
-          // If we close root and allocator in TaskCompletionListener, there could be a race
-          // condition where the writer thread keeps writing to the VectorSchemaRoot while
-          // it's being closed by the TaskCompletion listener.
-          // Closing root and allocator here is cleaner because root and allocator is owned
-          // by the writer thread and is only visible to the writer thread.
-          //
-          // If the writer thread is interrupted by TaskCompletionListener, it should either
-          // (1) in the try block, in which case it will get an InterruptedException when
-          // performing io, and goes into the finally block or (2) in the finally block,
-          // in which case it will ignore the interruption and close the resources.
-          leftRoot.close()
-          rightRoot.close()
+        }{
+          root.close()
          allocator.close()
        }
      }
    }
  }
 }
+
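
With this change each group is written as its own standard Arrow IPC stream via ArrowStreamWriter, bracketed by the START_ARROW_STREAM/END_OF_DATA_SECTION ints, which is what lets the custom InterleavedArrowWriter in the next section be deleted: an IPC stream is self-delimiting (it carries its own schema header, batches, and an end-of-stream marker), so back-to-back streams on one pipe can be read one at a time. One apparent trade-off is that the schema is re-sent with every group's stream. A small pyarrow sketch of the self-delimiting property, illustrative only and not code from the patch:

import io
import pyarrow as pa

buf = io.BytesIO()
for values in ([1, 2, 3], [10, 20]):
    batch = pa.RecordBatch.from_arrays([pa.array(values)], names=["x"])
    writer = pa.ipc.new_stream(buf, batch.schema)  # writes this stream's own schema header
    writer.write_batch(batch)
    writer.close()                                 # writes the end-of-stream marker
buf.seek(0)

first = list(pa.ipc.open_stream(buf))   # stops at the first stream's EOS marker
second = list(pa.ipc.open_stream(buf))  # continues with the second stream
print(first[0].num_rows, second[0].num_rows)  # 3 2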

sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala

Lines changed: 0 additions & 85 deletions
This file was deleted.
