@@ -30,6 +30,9 @@ import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.NextIterator
+
+import scala.annotation.tailrec
 
 /**
  * A batch-oriented interface for consuming from Kafka.
@@ -83,11 +86,11 @@ private[spark] class KafkaRDD[K, V](
 
   override def getPartitions: Array[Partition] = {
     offsetRanges.zipWithIndex.map { case (o, i) =>
-      new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset)
+        new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset)
     }.toArray
   }
 
-  override def count(): Long = offsetRanges.map(_.count).sum
+  // override def count(): Long = offsetRanges.map(_.count).sum
 
   override def countApprox(
       timeout: Long,
@@ -193,7 +196,7 @@ private[spark] class KafkaRDD[K, V](
    */
   private class KafkaRDDIterator(
       part: KafkaRDDPartition,
-      context: TaskContext) extends Iterator[ConsumerRecord[K, V]] {
+      context: TaskContext) extends NextIterator[ConsumerRecord[K, V]] {
 
     logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " +
       s"offsets ${part.fromOffset} -> ${part.untilOffset}")
@@ -215,19 +218,41 @@ private[spark] class KafkaRDD[K, V](
 
     var requestOffset = part.fromOffset
 
-    def closeIfNeeded(): Unit = {
+    override def close(): Unit = {
       if (!useConsumerCache && consumer != null) {
         consumer.close
       }
     }
 
-    override def hasNext(): Boolean = requestOffset < part.untilOffset
+    // override def hasNext(): Boolean = requestOffset < part.untilOffset
+    override def getNext(): ConsumerRecord[K, V] = {
+
+      @tailrec
+      def skipGapsAndGetNext: ConsumerRecord[K, V] = {
+        if (requestOffset < part.untilOffset) {
+          val r = consumer.get(requestOffset, pollTimeout)
+
+          requestOffset = if (r.offset() == 0) { part.untilOffset } else { r.offset() + 1 }
+
+          if (null == r && r.offset() == 0) {
+            skipGapsAndGetNext
+          } else {
+            r
+          }
+        } else {
+          finished = true
+          null.asInstanceOf[ConsumerRecord[K, V]]
+        }
+      }
 
-    override def next(): ConsumerRecord[K, V] = {
-      assert(hasNext(), "Can't call getNext() once untilOffset has been reached")
-      val r = consumer.get(requestOffset, pollTimeout)
-      requestOffset += 1
-      r
+      skipGapsAndGetNext
     }
+
+    // override def next(): ConsumerRecord[K, V] = {
+    //   assert(hasNext, "Can't call getNext() once untilOffset has been reached")
+    //   val r = consumer.get(requestOffset, pollTimeout)
+    //   requestOffset += 1
+    //   r
+    // }
   }
 }
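
For readers who want to see the gap-skipping idea in isolation, here is a self-contained Scala sketch of the same pattern. It is illustrative only: Spark's org.apache.spark.util.NextIterator is private[spark], so a minimal equivalent is defined inline, and OffsetConsumer is a hypothetical stand-in for the cached Kafka consumer's get(offset, timeout) call, assumed to return null when an offset yields no record (for example, because the topic is compacted). Unlike the diff above, the sketch checks the record for null before calling offset() on it.

import org.apache.kafka.clients.consumer.ConsumerRecord

// Minimal stand-in for org.apache.spark.util.NextIterator, which is internal to Spark:
// subclasses return elements from getNext() and set `finished` when there are no more.
abstract class SimpleNextIterator[U] extends Iterator[U] {
  protected var finished = false
  private var gotNext = false
  private var nextValue: U = _

  protected def getNext(): U
  protected def close(): Unit = {}

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) close()
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}

// Hypothetical consumer facade: returns the record at `offset`, or null when that
// offset is missing from the partition (a "gap").
trait OffsetConsumer[K, V] {
  def get(offset: Long, timeout: Long): ConsumerRecord[K, V]
}

// Gap-skipping iterator: advances past missing offsets instead of assuming that
// every offset in [fromOffset, untilOffset) produces a record.
class GapSkippingIterator[K, V](
    consumer: OffsetConsumer[K, V],
    fromOffset: Long,
    untilOffset: Long,
    pollTimeout: Long) extends SimpleNextIterator[ConsumerRecord[K, V]] {

  private var requestOffset = fromOffset

  @scala.annotation.tailrec
  private def skipGapsAndGetNext(): ConsumerRecord[K, V] = {
    if (requestOffset >= untilOffset) {
      // Reached the end of the requested range.
      finished = true
      null.asInstanceOf[ConsumerRecord[K, V]]
    } else {
      val r = consumer.get(requestOffset, pollTimeout)
      if (r == null) {
        // Gap: no record at this offset, try the next one.
        requestOffset += 1
        skipGapsAndGetNext()
      } else {
        // Resume after the offset we actually received.
        requestOffset = r.offset() + 1
        r
      }
    }
  }

  override protected def getNext(): ConsumerRecord[K, V] = skipGapsAndGetNext()
}

Under these assumptions, a KafkaRDD-style compute() could wrap the per-partition consumer in such an iterator rather than deriving the record count purely from untilOffset - fromOffset, which is why the diff also disables the offset-based count() override.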