
Commit f6e6899

jose-torres authored and tdas committed
[SPARK-24386][SS] coalesce(1) aggregates in continuous processing
## What changes were proposed in this pull request?

Provide a continuous processing implementation of coalesce(1), as well as allowing aggregates on top of it.

The changes in ContinuousQueuedDataReader and such are to use split.index (the ID of the partition within the RDD currently being compute()d) rather than context.partitionId() (the partition ID of the scheduled task within the Spark job - that is, the post coalesce writer). In the absence of a narrow dependency, these values were previously always the same, so there was no need to distinguish.

## How was this patch tested?

new unit test

Author: Jose Torres <torres.joseph.f+github@gmail.com>

Closes #21560 from jose-torres/coalesce.
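For orientation, the pattern this change enables looks roughly like the following. This is a hedged sketch: the rate source, console sink, output mode, trigger interval, and grouping expression are illustrative choices, not taken from the patch or its new unit test.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder()
  .master("local[4]")
  .appName("continuous-coalesce-aggregate-sketch")
  .getOrCreate()
import spark.implicits._

// Any ContinuousReader-backed source will do; "rate" is used here for illustration.
val input = spark.readStream.format("rate").load()

val counts = input
  .coalesce(1)              // planned as ContinuousCoalesceExec when the child is a continuous relation
  .groupBy($"value" % 10)   // aggregates are now allowed on top of the single-partition coalesce
  .count()

val query = counts.writeStream
  .format("console")
  .outputMode("complete")   // output mode choice is illustrative
  .trigger(Trigger.Continuous("1 second"))
  .start()
```

Without the `.coalesce(1)`, the UnsupportedOperationChecker change below rejects the aggregate at query start; with it, DataSourceV2Strategy replaces the logical `Repartition(1, shuffle = false)` with `ContinuousCoalesceExec`.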
1 parent 2224861 commit f6e6899

File tree

13 files changed: +310 -18 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala

Lines changed: 11 additions & 0 deletions
@@ -349,6 +349,17 @@ object UnsupportedOperationChecker {
              _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias |
              _: TypedFilter) =>
         case node if node.nodeName == "StreamingRelationV2" =>
+        case Repartition(1, false, _) =>
+        case node: Aggregate =>
+          val aboveSinglePartitionCoalesce = node.find {
+            case Repartition(1, false, _) => true
+            case _ => false
+          }.isDefined
+
+          if (!aboveSinglePartitionCoalesce) {
+            throwError(s"In continuous processing mode, coalesce(1) must be called before " +
+              s"aggregate operation ${node.nodeName}.")
+          }
         case node =>
           throwError(s"Continuous processing does not support ${node.nodeName} operations.")
       }
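Taken together, the added cases mean an Aggregate passes the continuous-mode check only when a shuffle-free Repartition(1), i.e. coalesce(1), sits somewhere beneath it. A hedged sketch of what that means for user code (the rate source and column names are illustrative, not from the patch):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("coalesce-check-sketch").getOrCreate()
import spark.implicits._

val df = spark.readStream.format("rate").load()

// Rejected when a continuous-trigger query is started over this plan:
// "In continuous processing mode, coalesce(1) must be called before aggregate operation Aggregate."
val rejected = df.groupBy($"value").count()

// Passes the check: node.find locates Repartition(1, shuffle = false) beneath the Aggregate.
val accepted = df.coalesce(1).groupBy($"value").count()

// Also rejected, by the generic fallback case, since repartition(1) is Repartition(1, shuffle = true):
// "Continuous processing does not support Repartition operations."
val alsoRejected = df.repartition(1).groupBy($"value").count()
```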

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala

Lines changed: 14 additions & 2 deletions
@@ -22,11 +22,12 @@ import scala.collection.mutable
 import org.apache.spark.sql.{sources, Strategy}
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader

 object DataSourceV2Strategy extends Strategy {

@@ -141,6 +142,17 @@ object DataSourceV2Strategy extends Strategy {
     case WriteToContinuousDataSource(writer, query) =>
       WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil

+    case Repartition(1, false, child) =>
+      val isContinuous = child.collectFirst {
+        case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r
+      }.isDefined
+
+      if (isContinuous) {
+        ContinuousCoalesceExec(1, planLater(child)) :: Nil
+      } else {
+        Nil
+      }
+
     case _ => Nil
   }
 }
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.UUID
+
+import org.apache.spark.{HashPartitioner, SparkEnv}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD}
+
+/**
+ * Physical plan for coalescing a continuous processing plan.
+ *
+ * Currently, only coalesces to a single partition are supported. `numPartitions` must be 1.
+ */
+case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan {
+  override def output: Seq[Attribute] = child.output
+
+  override def children: Seq[SparkPlan] = child :: Nil
+
+  override def outputPartitioning: Partitioning = SinglePartition
+
+  override def doExecute(): RDD[InternalRow] = {
+    assert(numPartitions == 1)
+    new ContinuousCoalesceRDD(
+      sparkContext,
+      numPartitions,
+      conf.continuousStreamingExecutorQueueSize,
+      sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong,
+      child.execute())
+  }
+}
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala

Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.UUID
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.continuous.shuffle._
+import org.apache.spark.util.ThreadUtils
+
+case class ContinuousCoalesceRDDPartition(
+    index: Int,
+    endpointName: String,
+    queueSize: Int,
+    numShuffleWriters: Int,
+    epochIntervalMs: Long)
+  extends Partition {
+  // Initialized only on the executor, and only once even as we call compute() multiple times.
+  lazy val (reader: ContinuousShuffleReader, endpoint) = {
+    val env = SparkEnv.get.rpcEnv
+    val receiver = new RPCContinuousShuffleReader(
+      queueSize, numShuffleWriters, epochIntervalMs, env)
+    val endpoint = env.setupEndpoint(endpointName, receiver)
+
+    TaskContext.get().addTaskCompletionListener { ctx =>
+      env.stop(endpoint)
+    }
+    (receiver, endpoint)
+  }
+  // This flag will be flipped on the executors to indicate that the threads processing
+  // partitions of the write-side RDD have been started. These will run indefinitely
+  // asynchronously as epochs of the coalesce RDD complete on the read side.
+  private[continuous] var writersInitialized: Boolean = false
+}
+
+/**
+ * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local
+ * continuous shuffle, and then reads them in the task thread using `reader`.
+ */
+class ContinuousCoalesceRDD(
+    context: SparkContext,
+    numPartitions: Int,
+    readerQueueSize: Int,
+    epochIntervalMs: Long,
+    prev: RDD[InternalRow])
+  extends RDD[InternalRow](context, Nil) {
+
+  // When we support more than 1 target partition, we'll need to figure out how to pass in the
+  // required partitioner.
+  private val outputPartitioner = new HashPartitioner(1)
+
+  private val readerEndpointNames = (0 until numPartitions).map { i =>
+    s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}"
+  }
+
+  override def getPartitions: Array[Partition] = {
+    (0 until numPartitions).map { partIndex =>
+      ContinuousCoalesceRDDPartition(
+        partIndex,
+        readerEndpointNames(partIndex),
+        readerQueueSize,
+        prev.getNumPartitions,
+        epochIntervalMs)
+    }.toArray
+  }
+
+  private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool(
+    prev.getNumPartitions,
+    this.name)
+
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    val part = split.asInstanceOf[ContinuousCoalesceRDDPartition]
+
+    if (!part.writersInitialized) {
+      val rpcEnv = SparkEnv.get.rpcEnv
+
+      // trigger lazy initialization
+      part.endpoint
+      val endpointRefs = readerEndpointNames.map { endpointName =>
+        rpcEnv.setupEndpointRef(rpcEnv.address, endpointName)
+      }
+
+      val runnables = prev.partitions.map { prevSplit =>
+        new Runnable() {
+          override def run(): Unit = {
+            TaskContext.setTaskContext(context)
+
+            val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter(
+              prevSplit.index, outputPartitioner, endpointRefs.toArray)
+
+            EpochTracker.initializeCurrentEpoch(
+              context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
+            while (!context.isInterrupted() && !context.isCompleted()) {
+              writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]])
+              // Note that current epoch is a non-inheritable thread local, so each writer thread
+              // can properly increment its own epoch without affecting the main task thread.
+              EpochTracker.incrementCurrentEpoch()
+            }
+          }
+        }
+      }
+
+      context.addTaskCompletionListener { ctx =>
+        threadPool.shutdownNow()
+      }
+
+      part.writersInitialized = true
+
+      runnables.foreach(threadPool.execute)
+    }
+
+    part.reader.read()
+  }
+
+  override def clearDependencies(): Unit = {
+    throw new IllegalStateException("Continuous RDDs cannot be checkpointed")
+  }
+}
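The inline comment about the epoch being a non-inheritable thread local is the key to this design: every writer Runnable keeps its own epoch counter and can advance it independently of the task thread doing the reading. A stand-alone sketch of that property, using a plain ThreadLocal rather than Spark's EpochTracker:

```scala
object EpochThreadLocalSketch {
  // Non-inheritable: a child thread does not see the parent's value, and writes in one
  // thread never leak into another, which is what the per-writer epochs rely on.
  private val currentEpoch = new ThreadLocal[Long] {
    override def initialValue(): Long = 0L
  }

  def main(args: Array[String]): Unit = {
    currentEpoch.set(5L) // the "task thread" sits at epoch 5

    val writer = new Thread(new Runnable {
      override def run(): Unit = {
        currentEpoch.set(5L)                     // the writer initializes its own copy
        currentEpoch.set(currentEpoch.get() + 1) // and advances it per completed epoch
        println(s"writer thread epoch: ${currentEpoch.get()}") // prints 6
      }
    })
    writer.start()
    writer.join()

    println(s"task thread epoch: ${currentEpoch.get()}") // still 5
  }
}
```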

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala

Lines changed: 3 additions & 4 deletions
@@ -51,11 +51,11 @@ class ContinuousDataSourceRDD(
     sc: SparkContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long,
-    @transient private val readerFactories: Seq[InputPartition[UnsafeRow]])
+    private val readerInputPartitions: Seq[InputPartition[UnsafeRow]])
   extends RDD[UnsafeRow](sc, Nil) {

   override protected def getPartitions: Array[Partition] = {
-    readerFactories.zipWithIndex.map {
+    readerInputPartitions.zipWithIndex.map {
       case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition)
     }.toArray
   }

@@ -74,8 +74,7 @@ class ContinuousDataSourceRDD(
     val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
     if (partition.queueReader == null) {
       partition.queueReader =
-        new ContinuousQueuedDataReader(
-          partition.inputPartition, context, dataQueueSize, epochPollIntervalMs)
+        new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs)
     }

     partition.queueReader

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala

Lines changed: 4 additions & 0 deletions
@@ -216,6 +216,9 @@ class ContinuousExecution(
     currentEpochCoordinatorId = epochCoordinatorId
     sparkSessionForQuery.sparkContext.setLocalProperty(
       ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId)
+    sparkSessionForQuery.sparkContext.setLocalProperty(
+      ContinuousExecution.EPOCH_INTERVAL_KEY,
+      trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString)

     // Use the parent Spark session for the endpoint since it's where this query ID is registered.
     val epochEndpoint =

@@ -382,4 +385,5 @@ class ContinuousExecution(
 object ContinuousExecution {
   val START_EPOCH_KEY = "__continuous_start_epoch"
   val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id"
+  val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval"
 }
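SparkContext local properties are the channel used here: the driver thread sets EPOCH_INTERVAL_KEY (and, elsewhere in this patch, START_EPOCH_KEY), and the value is then visible both to driver-side planning via sparkContext.getLocalProperty and inside tasks via TaskContext. A minimal stand-alone sketch of the mechanism; the key string comes from the diff above, everything else is illustrative:

```scala
import org.apache.spark.{SparkContext, TaskContext}

val sc = new SparkContext("local[2]", "local-property-sketch")

// Driver side: local properties are per-thread and shipped with every task
// submitted by jobs started from this thread.
sc.setLocalProperty("__continuous_epoch_interval", "1000")

sc.parallelize(1 to 2, 2).foreach { _ =>
  // Executor side: the same property is readable from the TaskContext.
  val intervalMs = TaskContext.get().getLocalProperty("__continuous_epoch_interval").toLong
  println(s"epoch interval seen inside the task: $intervalMs ms")
}

sc.stop()
```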

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala

Lines changed: 3 additions & 3 deletions
@@ -37,11 +37,11 @@ import org.apache.spark.util.ThreadUtils
  * offsets across epochs. Each compute() should call the next() method here until null is returned.
  */
 class ContinuousQueuedDataReader(
-    partition: InputPartition[UnsafeRow],
+    partition: ContinuousDataSourceRDDPartition,
     context: TaskContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long) extends Closeable {
-  private val reader = partition.createPartitionReader()
+  private val reader = partition.inputPartition.createPartitionReader()

   // Important sequencing - we must get our starting point before the provider threads start running
   private var currentOffset: PartitionOffset =

@@ -113,7 +113,7 @@ class ContinuousQueuedDataReader(
       currentEntry match {
         case EpochMarker =>
           epochCoordEndpoint.send(ReportPartitionOffset(
-            context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset))
+            partition.index, EpochTracker.getCurrentEpoch.get, currentOffset))
           null
         case ContinuousRow(row, offset) =>
           currentOffset = offset
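The reason the reader must now report split.index rather than context.partitionId() can be reproduced with the plain RDD API. In the sketch below (names and numbers illustrative), a shuffle-free coalesce means all four parent splits are computed inside one task whose partitionId() is 0, while each split keeps its own index, which is the value the epoch coordinator needs per reader partition:

```scala
import org.apache.spark.{SparkContext, TaskContext}

val sc = new SparkContext("local[4]", "split-index-vs-partition-id")

sc.parallelize(1 to 8, numSlices = 4)
  .mapPartitionsWithIndex { (splitIndex, iter) =>
    // splitIndex: the partition of the RDD currently being compute()d (0..3)
    // partitionId(): the partition of the scheduled task (always 0 after coalesce(1))
    iter.map(x => (splitIndex, TaskContext.get().partitionId(), x))
  }
  .coalesce(1, shuffle = false) // narrow dependency, analogous to the continuous coalesce
  .collect()
  .foreach(println)             // splitIndex varies over 0..3 while partitionId stays 0

sc.stop()
```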

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala

Lines changed: 7 additions & 3 deletions
@@ -21,12 +21,14 @@ import java.util.UUID

 import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.RpcAddress
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.NextIterator

 case class ContinuousShuffleReadPartition(
     index: Int,
+    endpointName: String,
     queueSize: Int,
     numShuffleWriters: Int,
     epochIntervalMs: Long)

@@ -36,7 +38,7 @@ case class ContinuousShuffleReadPartition(
     val env = SparkEnv.get.rpcEnv
     val receiver = new RPCContinuousShuffleReader(
       queueSize, numShuffleWriters, epochIntervalMs, env)
-    val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver)
+    val endpoint = env.setupEndpoint(endpointName, receiver)

     TaskContext.get().addTaskCompletionListener { ctx =>
       env.stop(endpoint)

@@ -61,12 +63,14 @@ class ContinuousShuffleReadRDD(
     numPartitions: Int,
     queueSize: Int = 1024,
     numShuffleWriters: Int = 1,
-    epochIntervalMs: Long = 1000)
+    epochIntervalMs: Long = 1000,
+    val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}"))
   extends RDD[UnsafeRow](sc, Nil) {

   override protected def getPartitions: Array[Partition] = {
     (0 until numPartitions).map { partIndex =>
-      ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs)
+      ContinuousShuffleReadPartition(
+        partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs)
     }.toArray
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContin
  * TODO: Support multiple source tasks. We need to output a single epoch marker once all
  * source tasks have sent one.
  */
-private[shuffle] class RPCContinuousShuffleReader(
+private[continuous] class RPCContinuousShuffleReader(
     queueSize: Int,
     numShuffleWriters: Int,
     epochIntervalMs: Long,

@@ -107,7 +107,7 @@ private[shuffle] class RPCContinuousShuffleReader(
           }
           logWarning(
             s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " +
-              s"for writers $writerIdsUncommitted to send epoch markers.")
+              s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.")

         // The completion service guarantees this future will be available immediately.
         case future => future.get() match {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala

Lines changed: 10 additions & 1 deletion
@@ -23,12 +23,13 @@ import java.util.concurrent.atomic.AtomicInteger
 import javax.annotation.concurrent.GuardedBy

 import scala.collection.JavaConverters._
+import scala.collection.SortedMap
 import scala.collection.mutable.ListBuffer

 import org.json4s.NoTypeHints
 import org.json4s.jackson.Serialization

-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.sql.{Encoder, Row, SQLContext}
 import org.apache.spark.sql.execution.streaming._

@@ -184,6 +185,14 @@ class ContinuousMemoryStreamInputPartitionReader(
   private var currentOffset = startOffset
   private var current: Option[Row] = None

+  // Defense-in-depth against failing to propagate the task context. Since it's not inheritable,
+  // we have to do a bit of error prone work to get it into every thread used by continuous
+  // processing. We hope that some unit test will end up instantiating a continuous memory stream
+  // in such cases.
+  if (TaskContext.get() == null) {
+    throw new IllegalStateException("Task context was not set!")
+  }
+
   override def next(): Boolean = {
     current = getRecord
     while (current.isEmpty) {
