Commit aa5b097 (parent: 93b74f8)

Add more comments, made all PID constant parameters positive, a couple more tests.

3 files changed: 81 additions, 30 deletions

streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala
(32 additions, 17 deletions)
@@ -28,20 +28,22 @@ package org.apache.spark.streaming.scheduler.rate
  *
  * @param batchDurationMillis the batch duration, in milliseconds
  * @param proportional how much the correction should depend on the current
- *        error. This term usually provides the bulk of correction. A value too large would
- *        make the controller overshoot the setpoint, while a small value would make the
- *        controller too insensitive. The default value is -1.
+ *        error. This term usually provides the bulk of correction and should be positive or zero.
+ *        A value too large would make the controller overshoot the setpoint, while a small value
+ *        would make the controller too insensitive. The default value is 1.
  * @param integral how much the correction should depend on the accumulation
- *        of past errors. This term accelerates the movement towards the setpoint, but a large
- *        value may lead to overshooting. The default value is -0.2.
+ *        of past errors. This value should be positive or 0. This term accelerates the movement
+ *        towards the desired value, but a large value may lead to overshooting. The default value
+ *        is 0.2.
  * @param derivative how much the correction should depend on a prediction
- *        of future errors, based on current rate of change. This term is not used very often,
- *        as it impacts stability of the system. The default value is 0.
+ *        of future errors, based on current rate of change. This value should be positive or 0.
+ *        This term is not used very often, as it impacts stability of the system. The default
+ *        value is 0.
  */
 private[streaming] class PIDRateEstimator(
     batchIntervalMillis: Long,
-    proportional: Double = -1D,
-    integral: Double = -.2D,
+    proportional: Double = 1D,
+    integral: Double = .2D,
     derivative: Double = 0D)
   extends RateEstimator {

@@ -53,9 +55,19 @@ private[streaming] class PIDRateEstimator(
   require(
     batchIntervalMillis > 0,
     s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.")
+  require(
+    proportional >= 0,
+    s"Proportional term $proportional in PIDRateEstimator should be >= 0.")
+  require(
+    integral >= 0,
+    s"Integral term $integral in PIDRateEstimator should be >= 0.")
+  require(
+    derivative >= 0,
+    s"Derivative term $derivative in PIDRateEstimator should be >= 0.")
+

   def compute(time: Long, // in milliseconds
-      elements: Long,
+      numElements: Long,
       processingDelay: Long, // in milliseconds
       schedulingDelay: Long // in milliseconds
     ): Option[Double] = {

@@ -67,16 +79,19 @@ private[streaming] class PIDRateEstimator(
       val delaySinceUpdate = (time - latestTime).toDouble / 1000

       // in elements/second
-      val processingRate = elements.toDouble / processingDelay * 1000
+      val processingRate = numElements.toDouble / processingDelay * 1000

+      // In our system `error` is the difference between the desired rate and the measured rate
+      // based on the latest batch information. We consider the desired rate to be latest rate,
+      // which is what this estimator calculated for the previous batch.
       // in elements/second
       val error = latestRate - processingRate

-      // The error integral, based on schedulingDelay as an indicator for accumulated errors
-      // a scheduling delay s corresponds to s * processingRate overflowing elements. Those
+      // The error integral, based on schedulingDelay as an indicator for accumulated errors.
+      // A scheduling delay s corresponds to s * processingRate overflowing elements. Those
       // are elements that couldn't be processed in previous batches, leading to this delay.
-      // We assume the processingRate didn't change too much.
-      // from the number of overflowing elements we can calculate the rate at which they would be
+      // In the following, we assume the processingRate didn't change too much.
+      // From the number of overflowing elements we can calculate the rate at which they would be
       // processed by dividing it by the batch interval. This rate is our "historical" error,
       // or integral part, since if we subtracted this rate from the previous "calculated rate",
       // there wouldn't have been any overflowing elements, and the scheduling delay would have

@@ -87,8 +102,8 @@ private[streaming] class PIDRateEstimator(
       // in elements/(second ^ 2)
       val dError = (error - latestError) / delaySinceUpdate

-      val newRate = (latestRate + proportional * error +
-          integral * historicalError +
+      val newRate = (latestRate - proportional * error -
+          integral * historicalError -
           derivative * dError).max(0.0)
       latestTime = time
       if (firstRun) {
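To see why the sign flip is safe: with the constants made positive, the three correction terms are now subtracted from latestRate, which is algebraically identical to the old version that added terms weighted by negative defaults. Below is a minimal sketch of one update step using the new defaults; the batch measurements are hypothetical, and the historicalError line is an assumption reconstructed from the comments in the diff (the hunks above do not show it):

    // Sketch of a single PID update with the new positive constants (Scala script style).
    val proportional = 1.0   // new default
    val integral = 0.2       // new default
    val derivative = 0.0     // new default

    val latestRate = 1000.0          // rate chosen for the previous batch, elements/s
    val numElements = 180L           // hypothetical batch measurements
    val processingDelay = 200L       // ms spent processing the batch
    val schedulingDelay = 30L        // ms the batch waited before processing
    val batchIntervalMillis = 100L
    val delaySinceUpdate = 0.1       // s since the last update
    val latestError = 0.0

    val processingRate = numElements.toDouble / processingDelay * 1000  // 900 elements/s
    val error = latestRate - processingRate                             // 100 elements/s
    // Overflowing elements from the scheduling delay, spread over one batch interval,
    // as described in the diff's comments (assumed formula, not shown in the hunks).
    val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis  // 270
    val dError = (error - latestError) / delaySinceUpdate               // 1000

    // Positive constants, subtracted corrections: 1000 - 1.0*100 - 0.2*270 - 0.0*1000 = 846.0
    val newRate = (latestRate - proportional * error -
      integral * historicalError -
      derivative * dError).max(0.0)
    println(newRate) // 846.0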

streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
(3 additions, 3 deletions)
@@ -58,9 +58,9 @@ object RateEstimator {
   def create(conf: SparkConf, batchInterval: Duration): RateEstimator =
     conf.get("spark.streaming.backpressure.rateEstimator", "pid") match {
       case "pid" =>
-        val proportional = conf.getDouble("spark.streraming.backpressure.pid.proportional", -1.0)
-        val integral = conf.getDouble("spark.streraming.backpressure.pid.integral", -0.2)
-        val derived = conf.getDouble("spark.streraming.backpressure.pid.derived", 0.0)
+        val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0)
+        val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2)
+        val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0)
         new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived)

       case estimator =>
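With the "streraming" typo fixed, the keys below are the ones actually read at runtime. A quick sketch of supplying them through SparkConf; the values are illustrative, and the final call only compiles from code inside the org.apache.spark.streaming package, since RateEstimator is private[streaming]:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.Seconds

    val conf = new SparkConf()
      .set("spark.streaming.backpressure.rateEstimator", "pid")
      .set("spark.streaming.backpressure.pid.proportional", "1.0")
      .set("spark.streaming.backpressure.pid.integral", "0.2")
      .set("spark.streaming.backpressure.pid.derived", "0.0")

    // Reads the keys above, falling back to the new positive defaults.
    val estimator = RateEstimator.create(conf, Seconds(1))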

streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
(46 additions, 10 deletions)
@@ -19,35 +19,64 @@ package org.apache.spark.streaming.scheduler.rate

 import scala.util.Random

-import org.scalatest._
+import org.scalatest.Inspectors.forAll
 import org.scalatest.Matchers
-import org.scalatest.Inspectors._

-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.streaming.Seconds

 class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {

+  test("the right estimator is created") {
+    val conf = new SparkConf
+    conf.set("spark.streaming.backpressure.rateEstimator", "pid")
+    val pid = RateEstimator.create(conf, Seconds(1))
+    pid.getClass should equal(classOf[PIDRateEstimator])
+  }
+
+  test("estimator checks ranges") {
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(0, 1, 2, 3)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, -1, 2, 3)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, 0, -1, 3)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, 0, 0, -1)
+    }
+  }
+
+  private def createDefaultEstimator: PIDRateEstimator = {
+    new PIDRateEstimator(20, 1D, 0D, 0D)
+  }
+
   test("first bound is None") {
-    val p = new PIDRateEstimator(20, -1D, 0D, 0D)
+    val p = createDefaultEstimator
     p.compute(0, 10, 10, 0) should equal(None)
   }

   test("second bound is rate") {
-    val p = new PIDRateEstimator(20, -1D, 0D, 0D)
+    val p = createDefaultEstimator
     p.compute(0, 10, 10, 0)
     // 1000 elements / s
     p.compute(10, 10, 10, 0) should equal(Some(1000))
   }

   test("works even with no time between updates") {
-    val p = new PIDRateEstimator(20, -1D, 0D, 0D)
+    val p = createDefaultEstimator
     p.compute(0, 10, 10, 0)
     p.compute(10, 10, 10, 0)
     p.compute(10, 10, 10, 0) should equal(None)
   }

   test("bound is never negative") {
-    val p = new PIDRateEstimator(20, -1D, -1D, 0D)
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    // prepare a series of batch updates, one every 20ms, 0 processed elements, 20ms of processing
+    // this might point the estimator to try and decrease the bound, but we test it never
+    // goes below zero, which would be nonsensical.
     val times = List.tabulate(50)(x => x * 20) // every 20ms
     val elements = List.fill(50)(0) // no processing
     val proc = List.fill(50)(20) // 20ms of processing
@@ -58,7 +87,10 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {
   }

   test("with no accumulated or positive error, |I| > 0, follow the processing speed") {
-    val p = new PIDRateEstimator(20, -1D, -1D, 0D)
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    // prepare a series of batch updates, one every 20ms with an increasing number of processed
+    // elements in each batch, but constant processing time, and no accumulated error. Even though
+    // the integral part is non-zero, the estimated rate should follow only the proportional term
     val times = List.tabulate(50)(x => x * 20) // every 20ms
     val elements = List.tabulate(50)(x => x * 20) // increasing
     val proc = List.fill(50)(20) // 20ms of processing
@@ -69,7 +101,11 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {
   }

   test("with no accumulated but some positive error, |I| > 0, follow the processing speed") {
-    val p = new PIDRateEstimator(20, -1D, -1D, 0D)
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    // prepare a series of batch updates, one every 20ms with a decreasing number of processed
+    // elements in each batch, but constant processing time, and no accumulated error. Even though
+    // the integral part is non-zero, the estimated rate should follow only the proportional term,
+    // asking for fewer and fewer elements
     val times = List.tabulate(50)(x => x * 20) // every 20ms
     val elements = List.tabulate(50)(x => (50 - x) * 20) // decreasing
     val proc = List.fill(50)(20) // 20ms of processing
@@ -80,7 +116,7 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {
   }

   test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") {
-    val p = new PIDRateEstimator(20, -1D, -.01D, 0D)
+    val p = new PIDRateEstimator(20, 1D, .01D, 0D)
     val times = List.tabulate(50)(x => x * 20) // every 20ms
     val rng = new Random()
     val elements = List.tabulate(50)(x => rng.nextInt(1000))