
Commit 468f134

fixes
1 parent 0b35766 commit 468f134

File tree: 6 files changed (+45, −35 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.mutable
 
 import org.apache.spark.sql.{sources, Strategy}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
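The hunk only adds And to the catalyst expression imports; the use site is not shown here. A plausible use, sketched below as an assumption rather than the commit's actual code, is folding the filters that could not be pushed to the source into a single conjunction before planning a FilterExec over the scan. The predicates in the example are made up.

import org.apache.spark.sql.catalyst.expressions.{And, Expression, GreaterThan, Literal}

object AndReductionSketch {
  // Fold a list of leftover predicates into one conjunctive Expression,
  // or None if there is nothing left to filter after pushdown.
  def conjoin(filters: Seq[Expression]): Option[Expression] =
    filters.reduceLeftOption(And)

  def main(args: Array[String]): Unit = {
    val predicates = Seq[Expression](
      GreaterThan(Literal(3), Literal(1)),
      GreaterThan(Literal(5), Literal(2)))
    // Prints the two comparisons joined into one And expression.
    println(conjoin(predicates))
  }
}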

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala

Lines changed: 1 addition & 4 deletions

@@ -41,14 +41,11 @@ case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends
 
   override def doExecute(): RDD[InternalRow] = {
     assert(numPartitions == 1)
-
-    val childRdd = child.execute()
-
     new ContinuousCoalesceRDD(
       sparkContext,
       numPartitions,
       conf.continuousStreamingExecutorQueueSize,
       sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong,
-      childRdd)
+      child.execute())
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala

Lines changed: 38 additions & 20 deletions

@@ -20,13 +20,31 @@ package org.apache.spark.sql.execution.streaming.continuous
 import java.util.UUID
 
 import org.apache.spark._
-import org.apache.spark.rdd.{CoalescedRDDPartition, RDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming.continuous.shuffle._
 import org.apache.spark.util.ThreadUtils
 
-case class ContinuousCoalesceRDDPartition(index: Int) extends Partition {
+case class ContinuousCoalesceRDDPartition(
+    index: Int,
+    endpointName: String,
+    queueSize: Int,
+    numShuffleWriters: Int,
+    epochIntervalMs: Long)
+  extends Partition {
+  // Initialized only on the executor, and only once even as we call compute() multiple times.
+  lazy val (reader: ContinuousShuffleReader, endpoint) = {
+    val env = SparkEnv.get.rpcEnv
+    val receiver = new RPCContinuousShuffleReader(
+      queueSize, numShuffleWriters, epochIntervalMs, env)
+    val endpoint = env.setupEndpoint(endpointName, receiver)
+
+    TaskContext.get().addTaskCompletionListener { ctx =>
+      env.stop(endpoint)
+    }
+    (receiver, endpoint)
+  }
   // This flag will be flipped on the executors to indicate that the threads processing
   // partitions of the write-side RDD have been started. These will run indefinitely
   // asynchronously as epochs of the coalesce RDD complete on the read side.

@@ -45,9 +63,6 @@ class ContinuousCoalesceRDD(
     prev: RDD[InternalRow])
   extends RDD[InternalRow](context, Nil) {
 
-  override def getPartitions: Array[Partition] =
-    (0 until numPartitions).map(ContinuousCoalesceRDDPartition).toArray
-
   // When we support more than 1 target partition, we'll need to figure out how to pass in the
   // required partitioner.
   private val outputPartitioner = new HashPartitioner(1)

@@ -56,27 +71,30 @@
     s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}"
   }
 
-  val readerRDD = new ContinuousShuffleReadRDD(
-    sparkContext,
-    numPartitions,
-    readerQueueSize,
-    prev.getNumPartitions,
-    epochIntervalMs,
-    readerEndpointNames)
+  override def getPartitions: Array[Partition] = {
+    (0 until numPartitions).map { partIndex =>
+      ContinuousCoalesceRDDPartition(
+        partIndex,
+        readerEndpointNames(partIndex),
+        readerQueueSize,
+        prev.getNumPartitions,
+        epochIntervalMs)
+    }.toArray
+  }
 
   private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool(
     prev.getNumPartitions,
     this.name)
 
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
-    // lazy initialize endpoints so writer can send to them
-    readerRDD.partitions.foreach {
-      _.asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    }
+    val part = split.asInstanceOf[ContinuousCoalesceRDDPartition]
 
-    if (!split.asInstanceOf[ContinuousCoalesceRDDPartition].writersInitialized) {
+    if (!part.writersInitialized) {
       val rpcEnv = SparkEnv.get.rpcEnv
-      val endpointRefs = readerRDD.endpointNames.map { endpointName =>
+
+      // trigger lazy initialization
+      part.endpoint
+      val endpointRefs = readerEndpointNames.map { endpointName =>
         rpcEnv.setupEndpointRef(rpcEnv.address, endpointName)
       }
 

@@ -104,12 +122,12 @@ class ContinuousCoalesceRDD(
         threadPool.shutdownNow()
       }
 
-      split.asInstanceOf[ContinuousCoalesceRDDPartition].writersInitialized = true
+      part.writersInitialized = true
 
       runnables.foreach(threadPool.execute)
    }
 
-    readerRDD.compute(readerRDD.partitions(split.index), context)
+    part.reader.read()
   }
 
   override def clearDependencies(): Unit = {
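Taken together, these hunks move the shuffle reader out of a separate ContinuousShuffleReadRDD and into the coalesce RDD's own partitions: each partition now carries only serializable configuration (endpoint name, queue size, writer count, epoch interval), and the RPC endpoint is registered lazily on the executor and stopped when the task completes. Below is a minimal sketch of the same pattern, with stand-in types and made-up names (ReaderLike, SketchCoalescePartition) instead of the real shuffle classes and without Spark RPC.

import java.util.UUID

import org.apache.spark.Partition

// The partition carries only serializable configuration fixed on the driver;
// the heavyweight reader is created lazily, once per deserialized partition,
// on first access from compute() on the executor.
trait ReaderLike { def read(): Iterator[String] }

case class SketchCoalescePartition(index: Int, endpointName: String) extends Partition {
  @transient lazy val reader: ReaderLike = new ReaderLike {
    override def read(): Iterator[String] = Iterator(s"rows received via $endpointName")
  }
}

object SketchCoalescePartition {
  // Driver side: endpoint names are chosen up front so write-side tasks can
  // later resolve the reader endpoints purely by name, mirroring how compute()
  // above calls setupEndpointRef with the names baked into getPartitions.
  def plan(numPartitions: Int): Array[Partition] =
    (0 until numPartitions).map { i =>
      SketchCoalescePartition(i, s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}")
    }.toArray
}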

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala

Lines changed: 3 additions & 8 deletions

@@ -51,11 +51,11 @@ class ContinuousDataSourceRDD(
     sc: SparkContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long,
-    private val readerFactories: Seq[InputPartition[UnsafeRow]])
+    private val readerInputPartitions: Seq[InputPartition[UnsafeRow]])
   extends RDD[UnsafeRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
-    readerFactories.zipWithIndex.map {
+    readerInputPartitions.zipWithIndex.map {
       case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition)
     }.toArray
   }

@@ -74,8 +74,7 @@ class ContinuousDataSourceRDD(
     val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
     if (partition.queueReader == null) {
       partition.queueReader =
-        new ContinuousQueuedDataReader(
-          partition, context, dataQueueSize, epochPollIntervalMs)
+        new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs)
     }
 
     partition.queueReader

@@ -98,10 +97,6 @@ class ContinuousDataSourceRDD(
   override def getPreferredLocations(split: Partition): Seq[String] = {
     split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations()
   }
-
-  override def clearDependencies(): Unit = {
-    throw new IllegalStateException("Continuous RDDs cannot be checkpointed")
-  }
 }
 
 object ContinuousDataSourceRDD {
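Besides renaming readerFactories to readerInputPartitions and dropping the clearDependencies override, this file keeps its existing pattern of caching the queue reader on the mutable partition object inside compute(). A rough sketch of that caching, using a stand-in reader type and made-up names instead of ContinuousQueuedDataReader:

import org.apache.spark.{Partition, TaskContext}

// Illustrative only: the reader is built at most once per task, because it is
// stored on the partition object the first time compute() asks for it.
class SketchQueuedReader(val partitionIndex: Int) {
  def next(): String = s"row from partition $partitionIndex"
}

class SketchDataSourcePartition(val index: Int) extends Partition {
  // Populated lazily on the executor, inside compute().
  @transient var queueReader: SketchQueuedReader = _
}

object SketchDataSourceRDD {
  def readerFor(split: Partition, context: TaskContext): SketchQueuedReader = {
    val partition = split.asInstanceOf[SketchDataSourcePartition]
    if (partition.queueReader == null) {
      partition.queueReader = new SketchQueuedReader(partition.index)
    }
    partition.queueReader
  }
}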

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContin
  * TODO: Support multiple source tasks. We need to output a single epoch marker once all
  * source tasks have sent one.
  */
-private[shuffle] class RPCContinuousShuffleReader(
+private[continuous] class RPCContinuousShuffleReader(
     queueSize: Int,
     numShuffleWriters: Int,
     epochIntervalMs: Long,

sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -146,7 +146,7 @@ class ContinuousShuffleSuite extends StreamTest {
       val iter = rdd.compute(part, ctx)
       assert(iter.next().getInt(0) == part.index)
       assert(!iter.hasNext)
-    }Oh
+    }
   }
 
   test("reader - blocks waiting for new rows") {
