
Commit 761565a

Revert "[SPARK-23096][SS] Migrate rate source to V2"
This reverts commit c68ec4e.
1 parent 34c4b9c commit 761565a

File tree: 10 files changed (+844 / -715 lines)


sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister

Lines changed: 2 additions & 1 deletion
@@ -5,5 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 org.apache.spark.sql.execution.datasources.text.TextFileFormat
 org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
-org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
+org.apache.spark.sql.execution.streaming.RateSourceProvider
 org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
+org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2
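This service file is how Spark's data source lookup resolves the short name "rate" (the value returned by shortName()) back to RateSourceProvider once the revert is applied. A minimal usage sketch, assuming a local SparkSession and the source tree at this commit; names such as "rate-source-example" are illustrative only:

import org.apache.spark.sql.SparkSession

object RateSourceExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("rate-source-example")
      .getOrCreate()

    // "rate" is resolved to RateSourceProvider through the DataSourceRegister service file.
    val df = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()

    // The schema is fixed by the source: (timestamp: TimestampType, value: LongType).
    df.printSchema()

    // Writing to the console sink exercises the normal microbatch path.
    val query = df.writeStream.format("console").start()
    query.awaitTermination()
  }
}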

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala

Lines changed: 2 additions & 4 deletions
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
+import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.streaming.OutputMode
@@ -566,7 +566,6 @@ object DataSource extends Logging {
     val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
     val nativeOrc = classOf[OrcFileFormat].getCanonicalName
     val socket = classOf[TextSocketSourceProvider].getCanonicalName
-    val rate = classOf[RateStreamProvider].getCanonicalName

     Map(
       "org.apache.spark.sql.jdbc" -> jdbc,
@@ -588,8 +587,7 @@ object DataSource extends Logging {
       "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
       "org.apache.spark.ml.source.libsvm" -> libsvm,
       "com.databricks.spark.csv" -> csv,
-      "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
-      "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
+      "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket
     )
   }
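For context, the map edited above is Spark's backward-compatibility table: DataSource.lookupDataSource rewrites old fully qualified provider names to their current classes before loading them. After the revert, org.apache.spark.sql.execution.streaming.RateSourceProvider exists again as a real class, so its alias entry is simply dropped. A rough, simplified sketch of that rewrite step (a hypothetical helper, not the actual lookupDataSource implementation, which also consults the DataSourceRegister service loader and short names):

object ProviderAliasSketch {
  // Rewrite a provider name through the compatibility map; unknown names pass through unchanged.
  def resolve(name: String, compatMap: Map[String, String]): String =
    compatMap.getOrElse(name, name)

  def main(args: Array[String]): Unit = {
    val socket = "org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider"
    val compatMap = Map(
      "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket)

    // The socket alias still rewrites to the relocated class ...
    println(resolve("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", compatMap))
    // ... while the rate provider name now resolves directly, with no alias needed.
    println(resolve("org.apache.spark.sql.execution.streaming.RateSourceProvider", compatMap))
  }
}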

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala

Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming

import java.io._
import java.nio.charset.StandardCharsets
import java.util.Optional
import java.util.concurrent.TimeUnit

import org.apache.commons.io.IOUtils

import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
import org.apache.spark.sql.types._
import org.apache.spark.util.{ManualClock, SystemClock}

/**
 * A source that generates increment long values with timestamps. Each generated row has two
 * columns: a timestamp column for the generated time and an auto increment long column starting
 * with 0L.
 *
 * This source supports the following options:
 * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
 * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
 *   becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
 *   seconds.
 * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
 *   generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
 *   be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
 */
class RateSourceProvider extends StreamSourceProvider with DataSourceRegister
  with DataSourceV2 with ContinuousReadSupport {

  override def sourceSchema(
      sqlContext: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    if (schema.nonEmpty) {
      throw new AnalysisException("The rate source does not support a user-specified schema.")
    }

    (shortName(), RateSourceProvider.SCHEMA)
  }

  override def createSource(
      sqlContext: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    val params = CaseInsensitiveMap(parameters)

    val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L)
    if (rowsPerSecond <= 0) {
      throw new IllegalArgumentException(
        s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " +
          "must be positive")
    }

    val rampUpTimeSeconds =
      params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L)
    if (rampUpTimeSeconds < 0) {
      throw new IllegalArgumentException(
        s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " +
          "must not be negative")
    }

    val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse(
      sqlContext.sparkContext.defaultParallelism)
    if (numPartitions <= 0) {
      throw new IllegalArgumentException(
        s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " +
          "must be positive")
    }

    new RateStreamSource(
      sqlContext,
      metadataPath,
      rowsPerSecond,
      rampUpTimeSeconds,
      numPartitions,
      params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing
    )
  }

  override def createContinuousReader(
      schema: Optional[StructType],
      checkpointLocation: String,
      options: DataSourceOptions): ContinuousReader = {
    new RateStreamContinuousReader(options)
  }

  override def shortName(): String = "rate"
}

object RateSourceProvider {
  val SCHEMA =
    StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)

  val VERSION = 1
}

class RateStreamSource(
    sqlContext: SQLContext,
    metadataPath: String,
    rowsPerSecond: Long,
    rampUpTimeSeconds: Long,
    numPartitions: Int,
    useManualClock: Boolean) extends Source with Logging {

  import RateSourceProvider._
  import RateStreamSource._

  val clock = if (useManualClock) new ManualClock else new SystemClock

  private val maxSeconds = Long.MaxValue / rowsPerSecond

  if (rampUpTimeSeconds > maxSeconds) {
    throw new ArithmeticException(
      s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
        s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
  }

  private val startTimeMs = {
    val metadataLog =
      new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) {
        override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
          val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
          writer.write("v" + VERSION + "\n")
          writer.write(metadata.json)
          writer.flush
        }

        override def deserialize(in: InputStream): LongOffset = {
          val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
          // HDFSMetadataLog guarantees that it never creates a partial file.
          assert(content.length != 0)
          if (content(0) == 'v') {
            val indexOfNewLine = content.indexOf("\n")
            if (indexOfNewLine > 0) {
              val version = parseVersion(content.substring(0, indexOfNewLine), VERSION)
              LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
            } else {
              throw new IllegalStateException(
                s"Log file was malformed: failed to detect the log file version line.")
            }
          } else {
            throw new IllegalStateException(
              s"Log file was malformed: failed to detect the log file version line.")
          }
        }
      }

    metadataLog.get(0).getOrElse {
      val offset = LongOffset(clock.getTimeMillis())
      metadataLog.add(0, offset)
      logInfo(s"Start time: $offset")
      offset
    }.offset
  }

  /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */
  @volatile private var lastTimeMs = startTimeMs

  override def schema: StructType = RateSourceProvider.SCHEMA

  override def getOffset: Option[Offset] = {
    val now = clock.getTimeMillis()
    if (lastTimeMs < now) {
      lastTimeMs = now
    }
    Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs)))
  }

  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
    val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L)
    val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
    assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
    if (endSeconds > maxSeconds) {
      throw new ArithmeticException("Integer overflow. Max offset with " +
        s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
    }
    // Fix "lastTimeMs" for recovery
    if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) {
      lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs
    }
    val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
    val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
    logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
      s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")

    if (rangeStart == rangeEnd) {
      return sqlContext.internalCreateDataFrame(
        sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
    }

    val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
    val relativeMsPerValue =
      TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)

    val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
      val relative = math.round((v - rangeStart) * relativeMsPerValue)
      InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
    }
    sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
  }

  override def stop(): Unit = {}

  override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " +
    s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]"
}

object RateStreamSource {

  /** Calculate the end value we will emit at the time `seconds`. */
  def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
    // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
    // Then speedDeltaPerSecond = 2
    //
    // seconds   = 0 1 2  3  4  5  6
    // speed     = 0 2 4  6  8 10 10 (speedDeltaPerSecond * seconds)
    // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2
    val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
    if (seconds <= rampUpTimeSeconds) {
      // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
      // avoid overflow
      if (seconds % 2 == 1) {
        (seconds + 1) / 2 * speedDeltaPerSecond * seconds
      } else {
        seconds / 2 * speedDeltaPerSecond * (seconds + 1)
      }
    } else {
      // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
      val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
      rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
    }
  }
}
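
The ramp-up arithmetic in valueAtSecond is easiest to check against the worked example in its own comment (rowsPerSecond = 10, rampUpTime = 4s). A small self-contained sketch: the function body is copied from the file above, and only the driver loop is new:

object ValueAtSecondCheck {
  // Copied from RateStreamSource.valueAtSecond above.
  def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
    val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
    if (seconds <= rampUpTimeSeconds) {
      // Triangular-number formula, split by parity to avoid overflow.
      if (seconds % 2 == 1) (seconds + 1) / 2 * speedDeltaPerSecond * seconds
      else seconds / 2 * speedDeltaPerSecond * (seconds + 1)
    } else {
      val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
      rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
    }
  }

  def main(args: Array[String]): Unit = {
    // Expected end values from the comment: 0, 2, 6, 12, 20, 30, 40 for seconds 0..6.
    (0L to 6L).foreach { s =>
      println(s"second $s -> ${valueAtSecond(s, rowsPerSecond = 10, rampUpTimeSeconds = 4)}")
    }
  }
}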

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala

Lines changed: 6 additions & 19 deletions
@@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization

 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
-import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
+import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair}
+import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
@@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions)

   val creationTime = System.currentTimeMillis()

-  val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt
-  val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong
+  val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
+  val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
   val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble

   override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
@@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions)
     RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
   }

-  override def readSchema(): StructType = RateStreamProvider.SCHEMA
+  override def readSchema(): StructType = RateSourceProvider.SCHEMA

   private var offset: Offset = _

   override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
-    this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime))
+    this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime))
   }

   override def getStartOffset(): Offset = offset
@@ -98,19 +98,6 @@ class RateStreamContinuousReader(options: DataSourceOptions)
   override def commit(end: Offset): Unit = {}
   override def stop(): Unit = {}

-  private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
-    RateStreamOffset(
-      Range(0, numPartitions).map { i =>
-        // Note that the starting offset is exclusive, so we have to decrement the starting value
-        // by the increment that will later be applied. The first row output in each
-        // partition will have a value equal to the partition index.
-        (i,
-          ValueRunTimeMsPair(
-            (i - numPartitions).toLong,
-            creationTimeMs))
-      }.toMap)
-  }
-
 }

 case class RateStreamContinuousDataReaderFactory(
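
The createInitialOffset helper removed here (under this revert it is expected to live on RateStreamSourceV2 instead) seeds partition i at i - numPartitions, so that after the first increment of numPartitions each partition's first emitted value equals its partition index. A standalone sketch of just that arithmetic, independent of the Spark classes:

object InitialOffsetSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    // Starting value per partition, mirroring the removed createInitialOffset:
    // the starting offset is exclusive, so it is pre-decremented by the
    // per-step increment (numPartitions).
    val startValues = (0 until numPartitions).map(i => i -> (i - numPartitions).toLong).toMap
    startValues.toSeq.sortBy(_._1).foreach { case (partition, start) =>
      // The first value actually emitted is start + numPartitions == partition index.
      println(s"partition $partition: start=$start, first emitted value=${start + numPartitions}")
    }
  }
}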
