Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.streaming.receiver

import org.apache.spark.{Logging, SparkConf}
import java.util.concurrent.TimeUnit._
import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter}

/** Provides waitToPush() method to limit the rate at which receivers consume data.
*
Expand All @@ -33,37 +33,12 @@ import java.util.concurrent.TimeUnit._
*/
private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging {

  // Maximum number of messages per second a receiver may push, read from
  // "spark.streaming.receiver.maxRate". A value <= 0 (the default is 0) means
  // that no limit is applied.
  private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0)

  // Created lazily so that it is never constructed when rate limiting is disabled;
  // it is only touched from waitToPush() after checking desiredRate > 0
  // (Guava's RateLimiter.create rejects non-positive rates).
  private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate)

  /**
   * Blocks the caller, if necessary, so that messages are pushed at no more than
   * `desiredRate` per second. Returns immediately when no positive rate is configured.
   *
   * All throttling is delegated to the Guava rate limiter: one permit is acquired per
   * call. No hand-rolled bookkeeping (message counters, sync timestamps, sleeps or
   * recursive retries) is kept alongside it, so a message is never throttled twice.
   */
  def waitToPush(): Unit = {
    if (desiredRate > 0) {
      rateLimiter.acquire()
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
test("block generator throttling") {
val blockGeneratorListener = new FakeBlockGeneratorListener
val blockIntervalMs = 100
val maxRate = 100
val maxRate = 1001
val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms").
set("spark.streaming.receiver.maxRate", maxRate.toString)
val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf)
Expand All @@ -176,7 +176,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
blockGenerator.addData(count)
generatedData += count
count += 1
Thread.sleep(1)
}
blockGenerator.stop()

Expand All @@ -185,25 +184,31 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
assert(blockGeneratorListener.arrayBuffers.size > 0, "No blocks received")
assert(recordedData.toSet === generatedData.toSet, "Received data not same")

// recordedData size should be close to the expected rate
val minExpectedMessages = expectedMessages - 3
val maxExpectedMessages = expectedMessages + 1
// recordedData size should be close to the expected rate; use an error margin proportional to
// the value, so that rate changes don't cause a brittle test
val minExpectedMessages = expectedMessages - 0.05 * expectedMessages
val maxExpectedMessages = expectedMessages + 0.05 * expectedMessages
val numMessages = recordedData.size
assert(
numMessages >= minExpectedMessages && numMessages <= maxExpectedMessages,
s"#records received = $numMessages, not between $minExpectedMessages and $maxExpectedMessages"
)

val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3
val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1
// XXX Checking every block would require an even distribution of messages across blocks,
// which throttling code does not control. Therefore, test against the average.
val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 0.05 * expectedMessagesPerBlock
val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 0.05 * expectedMessagesPerBlock
val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",")

// the first and last block may be incomplete, so we slice them out
val validBlocks = recordedBlocks.drop(1).dropRight(1)
val averageBlockSize = validBlocks.map(block => block.size).sum / validBlocks.size

assert(
// the first and last block may be incomplete, so we slice them out
recordedBlocks.drop(1).dropRight(1).forall { block =>
block.size >= minExpectedMessagesPerBlock && block.size <= maxExpectedMessagesPerBlock
},
averageBlockSize >= minExpectedMessagesPerBlock &&
averageBlockSize <= maxExpectedMessagesPerBlock,
s"# records in received blocks = [$receivedBlockSizes], not between " +
s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock"
s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock, on average"
)
}

Expand Down