@@ -20,13 +20,31 @@ package org.apache.spark.sql.execution.streaming.continuous
 import java.util.UUID
 
 import org.apache.spark._
-import org.apache.spark.rdd.{CoalescedRDDPartition, RDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming.continuous.shuffle._
 import org.apache.spark.util.ThreadUtils
 
-case class ContinuousCoalesceRDDPartition(index: Int) extends Partition {
+case class ContinuousCoalesceRDDPartition(
+    index: Int,
+    endpointName: String,
+    queueSize: Int,
+    numShuffleWriters: Int,
+    epochIntervalMs: Long)
+  extends Partition {
+  // Initialized only on the executor, and only once even as we call compute() multiple times.
+  lazy val (reader: ContinuousShuffleReader, endpoint) = {
+    val env = SparkEnv.get.rpcEnv
+    val receiver = new RPCContinuousShuffleReader(
+      queueSize, numShuffleWriters, epochIntervalMs, env)
+    val endpoint = env.setupEndpoint(endpointName, receiver)
+
+    TaskContext.get().addTaskCompletionListener { ctx =>
+      env.stop(endpoint)
+    }
+    (receiver, endpoint)
+  }
   // This flag will be flipped on the executors to indicate that the threads processing
   // partitions of the write-side RDD have been started. These will run indefinitely
   // asynchronously as epochs of the coalesce RDD complete on the read side.
@@ -45,9 +63,6 @@ class ContinuousCoalesceRDD(
     prev: RDD[InternalRow])
   extends RDD[InternalRow](context, Nil) {
 
-  override def getPartitions: Array[Partition] =
-    (0 until numPartitions).map(ContinuousCoalesceRDDPartition).toArray
-
   // When we support more than 1 target partition, we'll need to figure out how to pass in the
   // required partitioner.
   private val outputPartitioner = new HashPartitioner(1)
@@ -56,27 +71,30 @@ class ContinuousCoalesceRDD(
     s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}"
   }
 
-  val readerRDD = new ContinuousShuffleReadRDD(
-    sparkContext,
-    numPartitions,
-    readerQueueSize,
-    prev.getNumPartitions,
-    epochIntervalMs,
-    readerEndpointNames)
+  override def getPartitions: Array[Partition] = {
+    (0 until numPartitions).map { partIndex =>
+      ContinuousCoalesceRDDPartition(
+        partIndex,
+        readerEndpointNames(partIndex),
+        readerQueueSize,
+        prev.getNumPartitions,
+        epochIntervalMs)
+    }.toArray
+  }
 
   private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool(
     prev.getNumPartitions,
     this.name)
 
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
-    // lazy initialize endpoints so writer can send to them
-    readerRDD.partitions.foreach {
-      _.asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    }
+    val part = split.asInstanceOf[ContinuousCoalesceRDDPartition]
 
-    if (!split.asInstanceOf[ContinuousCoalesceRDDPartition].writersInitialized) {
+    if (!part.writersInitialized) {
       val rpcEnv = SparkEnv.get.rpcEnv
-      val endpointRefs = readerRDD.endpointNames.map { endpointName =>
+
+      // trigger lazy initialization
+      part.endpoint
+      val endpointRefs = readerEndpointNames.map { endpointName =>
         rpcEnv.setupEndpointRef(rpcEnv.address, endpointName)
       }
 
@@ -104,12 +122,12 @@ class ContinuousCoalesceRDD(
         threadPool.shutdownNow()
       }
 
-      split.asInstanceOf[ContinuousCoalesceRDDPartition].writersInitialized = true
+      part.writersInitialized = true
 
       runnables.foreach(threadPool.execute)
     }
 
-    readerRDD.compute(readerRDD.partitions(split.index), context)
+    part.reader.read()
   }
 
   override def clearDependencies(): Unit = {
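The core pattern in this change is that ContinuousCoalesceRDDPartition now owns its shuffle reader endpoint as a lazy val: the partition object is serialized to the executor as part of the task, the lazy val is only evaluated there on first access from compute(), and a task completion listener tears the endpoint down when the task ends. A minimal standalone sketch of that executor-side lazy-initialization pattern, using a hypothetical ExpensiveResource in place of the RPC shuffle reader endpoint (names here are illustrative, not part of the patch):

    import org.apache.spark.{Partition, TaskContext}

    // Hypothetical stand-in for the reader/endpoint pair created in the real partition.
    class ExpensiveResource(name: String) {
      def read(): Iterator[Int] = Iterator.empty
      def close(): Unit = ()
    }

    case class LazyInitPartition(index: Int, resourceName: String) extends Partition {
      // Never evaluated on the driver; created once per deserialized partition on the
      // executor, and reused if compute() runs again against the same partition object.
      lazy val resource: ExpensiveResource = {
        val r = new ExpensiveResource(resourceName)
        TaskContext.get().addTaskCompletionListener { _ => r.close() }
        r
      }
    }

The same idea drives the removal of readerRDD: rather than a separate ContinuousShuffleReadRDD whose partitions hold the endpoints, each coalesce partition builds, reads from, and cleans up its own reader.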