Skip to content

Commit c3768c5

Browse files
committed
[Streaming][Kafka] Take advantage of offset range info for size-related KafkaRDD methods. Possible fix for [SPARK-7122], but probably a worthwhile optimization regardless.
1 parent c6a6dd0 commit c3768c5

File tree

4 files changed

+60
-1
lines changed

4 files changed

+60
-1
lines changed

external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,9 +17,11 @@
1717

1818
package org.apache.spark.streaming.kafka
1919

20+
import scala.collection.mutable.ArrayBuffer
2021
import scala.reflect.{classTag, ClassTag}
2122

2223
import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext}
24+
import org.apache.spark.partial.{PartialResult, BoundedDouble}
2325
import org.apache.spark.rdd.RDD
2426
import org.apache.spark.util.NextIterator
2527

@@ -60,6 +62,49 @@ class KafkaRDD[
6062
}.toArray
6163
}
6264

65+
/** Exact number of messages in this RDD, computed from the offset ranges (no job is run). */
override def count(): Long = offsetRanges.foldLeft(0L)((acc, range) => acc + range.count)
66+
67+
/**
 * Approximate count, returned immediately. Because the offset ranges give the exact
 * size up front, the result is final with confidence 1.0 and tight bounds.
 *
 * @param timeout ignored — no job needs to run, the answer is already exact
 * @param confidence ignored for the same reason
 */
override def countApprox(
    timeout: Long,
    confidence: Double = 0.95
  ): PartialResult[BoundedDouble] = {
  val exact = count()
  new PartialResult(new BoundedDouble(exact, 1.0, exact, exact), true)
}
74+
75+
/** True iff the offset ranges cover zero messages; decided without running a job. */
override def isEmpty(): Boolean = count() == 0L
76+
77+
/**
 * Return the first `num` messages of this RDD, reading only as many partitions
 * (and as many messages per partition) as the offset ranges say are needed.
 *
 * @param num maximum number of messages to return; non-positive yields an empty array
 * @return at most `num` messages, in partition order
 */
override def take(num: Int): Array[R] = {
  // Partitions known (from their offset ranges) to contain at least one message.
  val nonEmptyPartitions = this.partitions
    .map(_.asInstanceOf[KafkaRDDPartition])
    .filter(_.count > 0)

  if (num < 1 || nonEmptyPartitions.isEmpty) {
    new Array[R](0)
  } else {
    // Determine in advance how many messages need to be taken from each partition,
    // walking partitions in order until the budget is exhausted.
    var remaining = num.toLong
    val parts = nonEmptyPartitions.flatMap { part =>
      if (remaining > 0) {
        val taken = math.min(remaining, part.count)
        remaining -= taken
        Some(part.index -> taken.toInt)
      } else {
        None
      }
    }.toMap

    // Run a job only on the partitions that contribute messages.
    val buf = new ArrayBuffer[R]
    val res = context.runJob(
      this,
      (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray,
      parts.keys.toArray,
      allowLocal = true)
    res.foreach(buf ++= _)
    buf.toArray
  }
}
107+
63108
override def getPreferredLocations(thePart: Partition): Seq[String] = {
64109
val part = thePart.asInstanceOf[KafkaRDDPartition]
65110
// TODO is additional hostname resolution necessary here

external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,4 +35,7 @@ class KafkaRDDPartition(
3535
val untilOffset: Long,
3636
val host: String,
3737
val port: Int
38-
) extends Partition
38+
) extends Partition {
39+
/** Number of messages (untilOffset - fromOffset) this partition refers to. */
def count(): Long = untilOffset - fromOffset
41+
}

external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,9 @@ final class OffsetRange private(
5555
val untilOffset: Long) extends Serializable {
5656
import OffsetRange.OffsetRangeTuple
5757

58+
/** Number of messages (untilOffset - fromOffset) this OffsetRange refers to. */
def count(): Long = untilOffset - fromOffset
60+
5861
override def equals(obj: Any): Boolean = obj match {
5962
case that: OffsetRange =>
6063
this.topic == that.topic &&

external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,14 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
7070

7171
val received = rdd.map(_._2).collect.toSet
7272
assert(received === messages)
73+
74+
// size-related method optimizations return sane results
75+
assert(rdd.count === messages.size)
76+
assert(rdd.countApprox(0).getFinalValue.mean === messages.size)
77+
assert(! rdd.isEmpty)
78+
assert(rdd.take(1).size === 1)
79+
assert(messages(rdd.take(1).head._2))
80+
assert(rdd.take(messages.size + 10).size === messages.size)
7381
}
7482

7583
test("iterator boundary conditions") {

0 commit comments

Comments (0)