Skip to content

Commit 261a051

Browse files
committed
- removed field to hold the current rate limit in rate limiter
- made rate limit a Long and default to Long.MaxValue (consequence of the above) - removed custom `waitUntil` and replaced it by `eventually`
1 parent cd1397d commit 261a051

File tree

8 files changed

+36
-46
lines changed

8 files changed

+36
-46
lines changed

streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,25 @@ import org.apache.spark.{Logging, SparkConf}
3737
private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging {
3838

3939
// treated as an upper limit
40-
private val maxRateLimit = conf.getInt("spark.streaming.receiver.maxRate", 0)
41-
private[receiver] var currentRateLimit = new AtomicInteger(maxRateLimit)
42-
private lazy val rateLimiter = GuavaRateLimiter.create(currentRateLimit.get())
40+
private val maxRateLimit = conf.getLong("spark.streaming.receiver.maxRate", Long.MaxValue)
41+
private lazy val rateLimiter = GuavaRateLimiter.create(maxRateLimit.toDouble)
4342

4443
def waitToPush() {
45-
if (currentRateLimit.get() > 0) {
46-
rateLimiter.acquire()
47-
}
44+
rateLimiter.acquire()
4845
}
4946

50-
private[receiver] def updateRate(newRate: Int): Unit =
47+
/**
48+
* Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}.
49+
*/
50+
def getCurrentLimit: Long =
51+
rateLimiter.getRate.toLong
52+
53+
private[receiver] def updateRate(newRate: Long): Unit =
5154
if (newRate > 0) {
52-
try {
53-
if (maxRateLimit > 0) {
54-
currentRateLimit.set(newRate.min(maxRateLimit))
55-
}
56-
else {
57-
currentRateLimit.set(newRate)
58-
}
59-
} finally {
60-
rateLimiter.setRate(currentRateLimit.get())
55+
if (maxRateLimit > 0) {
56+
rateLimiter.setRate(newRate.min(maxRateLimit))
57+
} else {
58+
rateLimiter.setRate(newRate)
6159
}
6260
}
6361
}

streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
271271
}
272272

273273
/** Get the attached executor. */
274-
private[streaming] def executor = {
274+
private[streaming] def executor: ReceiverSupervisor = {
275275
assert(executor_ != null, "Executor has not been attached to this receiver")
276276
executor_
277277
}

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ private[streaming] abstract class ReceiverSupervisor(
5959
private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000)
6060

6161
/** The current maximum rate limit for this receiver. */
62-
private[streaming] def getCurrentRateLimit: Option[Int] = None
62+
private[streaming] def getCurrentRateLimit: Option[Long] = None
6363

6464
/** Exception associated with the stopping of the receiver */
6565
@volatile protected var stoppingError: Throwable = null

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ private[streaming] class ReceiverSupervisorImpl(
7878
logDebug("Received delete old batch signal")
7979
cleanupOldBlocks(threshTime)
8080
case UpdateRateLimit(eps) =>
81-
blockGenerator.updateRate(eps.toInt)
81+
logInfo(s"Received a new rate limit: $eps.")
82+
blockGenerator.updateRate(eps)
8283
}
8384
})
8485

@@ -100,8 +101,8 @@ private[streaming] class ReceiverSupervisorImpl(
100101
}
101102
}, streamId, env.conf)
102103

103-
override private[streaming] def getCurrentRateLimit: Option[Int] =
104-
Some(blockGenerator.currentRateLimit.get)
104+
override private[streaming] def getCurrentRateLimit: Option[Long] =
105+
Some(blockGenerator.getCurrentLimit)
105106

106107
/** Push a single record of received data into block generator. */
107108
def pushSingle(data: Any) {

streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
180180
logError(s"Deregistered receiver for stream $streamId: $messageWithError")
181181
}
182182

183-
/** Update a receiver's maximum rate from an estimator's update */
183+
/** Update a receiver's maximum ingestion rate */
184184
def sendRateUpdate(streamUID: Int, newRate: Long): Unit = {
185185
for (info <- receiverInfo.get(streamUID); eP <- Option(info.endpoint))
186186
eP.send(UpdateRateLimit(newRate))

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -537,19 +537,4 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
537537
verifyOutput[W](output, expectedOutput, useSet)
538538
}
539539
}
540-
541-
/**
542-
* Wait until `cond` becomes true, or timeout ms have passed. This method checks the condition
543-
* every 100ms, so it won't wait more than 100ms more than necessary.
544-
*
545-
* @param cond A boolean that should become `true`
546-
* @param timemout How many millis to wait before giving up
547-
*/
548-
def waitUntil(cond: => Boolean, timeout: Int): Unit = {
549-
val start = System.currentTimeMillis()
550-
val end = start + timeout
551-
while ((System.currentTimeMillis() < end) && !cond) {
552-
Thread.sleep(100)
553-
}
554-
}
555540
}

streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,20 @@ class RateLimiterSuite extends SparkFunSuite {
2727
val conf = new SparkConf()
2828
val rateLimiter = new RateLimiter(conf){}
2929
rateLimiter.updateRate(105)
30-
assert(rateLimiter.currentRateLimit.get == 105)
30+
assert(rateLimiter.getCurrentLimit == 105)
3131
}
3232

3333
test("rate limiter updates when below maxRate") {
3434
val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "110")
3535
val rateLimiter = new RateLimiter(conf){}
3636
rateLimiter.updateRate(105)
37-
assert(rateLimiter.currentRateLimit.get == 105)
37+
assert(rateLimiter.getCurrentLimit == 105)
3838
}
3939

4040
test("rate limiter stays below maxRate despite large updates") {
4141
val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "100")
4242
val rateLimiter = new RateLimiter(conf){}
4343
rateLimiter.updateRate(105)
44-
assert(rateLimiter.currentRateLimit.get == 100)
44+
assert(rateLimiter.getCurrentLimit === 100)
4545
}
46-
4746
}

streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
package org.apache.spark.streaming.scheduler
1919

20+
import org.scalatest.concurrent.Eventually._
21+
import org.scalatest.concurrent.Timeouts
22+
import org.scalatest.time.SpanSugar._
2023
import org.apache.spark.streaming._
2124
import org.apache.spark.SparkConf
2225
import org.apache.spark.storage.StorageLevel
@@ -77,15 +80,18 @@ class ReceiverTrackerSuite extends TestSuiteBase {
7780
}
7881

7982
test("Receiver tracker - propagates rate limit") {
80-
val newRateLimit = 100
83+
val newRateLimit = 100L
8184
val ids = new TestReceiverInputDStream(ssc)
8285
val tracker = new ReceiverTracker(ssc)
8386
tracker.start()
84-
waitUntil(TestDummyReceiver.started, 5000)
87+
eventually(timeout(5 seconds)) {
88+
assert(TestDummyReceiver.started)
89+
}
8590
tracker.sendRateUpdate(ids.id, newRateLimit)
8691
// this is an async message, we need to wait a bit for it to be processed
87-
waitUntil(ids.getRateLimit.get == newRateLimit, 1000)
88-
assert(ids.getRateLimit.get === newRateLimit)
92+
eventually(timeout(3 seconds)) {
93+
assert(ids.getCurrentRateLimit.get === newRateLimit)
94+
}
8995
}
9096
}
9197

@@ -95,8 +101,9 @@ private class TestReceiverInputDStream(@transient ssc_ : StreamingContext)
95101

96102
override def getReceiver(): DummyReceiver = TestDummyReceiver
97103

98-
def getRateLimit: Option[Int] =
104+
def getCurrentRateLimit: Option[Long] = {
99105
TestDummyReceiver.executor.getCurrentRateLimit
106+
}
100107
}
101108

102109
/**

0 commit comments

Comments
 (0)