[SPARK-9065][Streaming][PySpark] Add MessageHandler for Kafka Python API #7410

Closed · wants to merge 5 commits
@@ -17,25 +17,29 @@

package org.apache.spark.streaming.kafka

import java.io.OutputStream
import java.lang.{Integer => JInt, Long => JLong}
import java.util.{List => JList, Map => JMap, Set => JSet}

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import com.google.common.base.Charsets.UTF_8
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder}
import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder}
import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler}

import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.streaming.util.WriteAheadLogUtils
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext}
import org.apache.spark.streaming.api.java._
import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
import org.apache.spark.streaming.util.WriteAheadLogUtils
import org.apache.spark.{SparkContext, SparkException}

object KafkaUtils {
/**
@@ -184,6 +188,27 @@ object KafkaUtils {
}
}

private[kafka] def getFromOffsets(
kc: KafkaCluster,
kafkaParams: Map[String, String],
topics: Set[String]
): Map[TopicAndPartition, Long] = {
val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
val result = for {
topicPartitions <- kc.getPartitions(topics).right
leaderOffsets <- (if (reset == Some("smallest")) {
kc.getEarliestLeaderOffsets(topicPartitions)
} else {
kc.getLatestLeaderOffsets(topicPartitions)
}).right
} yield {
leaderOffsets.map { case (tp, lo) =>
(tp, lo.offset)
}
}
KafkaCluster.checkErrors(result)
}

/**
* Create a RDD from Kafka using offset ranges for each topic and partition.
*
@@ -246,7 +271,7 @@
// This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker
leaders.map {
case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port))
}.toMap
}
}
val cleanedHandler = sc.clean(messageHandler)
checkOffsets(kc, offsetRanges)
@@ -406,23 +431,9 @@ object KafkaUtils {
): InputDStream[(K, V)] = {
val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
val kc = new KafkaCluster(kafkaParams)
val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)

val result = for {
topicPartitions <- kc.getPartitions(topics).right
leaderOffsets <- (if (reset == Some("smallest")) {
kc.getEarliestLeaderOffsets(topicPartitions)
} else {
kc.getLatestLeaderOffsets(topicPartitions)
}).right
} yield {
val fromOffsets = leaderOffsets.map { case (tp, lo) =>
(tp, lo.offset)
}
new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
ssc, kafkaParams, fromOffsets, messageHandler)
}
KafkaCluster.checkErrors(result)
val fromOffsets = getFromOffsets(kc, kafkaParams, topics)
new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
ssc, kafkaParams, fromOffsets, messageHandler)
}

/**
@@ -550,6 +561,8 @@ object KafkaUtils {
* takes care of known parameters instead of passing them from Python
*/
private[kafka] class KafkaUtilsPythonHelper {
import KafkaUtilsPythonHelper._

def createStream(
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
@@ -570,82 +583,61 @@ private[kafka] class KafkaUtilsPythonHelper {
jsc: JavaSparkContext,
kafkaParams: JMap[String, String],
offsetRanges: JList[OffsetRange],
leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = {
val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
(Array[Byte], Array[Byte])] {
def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
(t1.key(), t1.message())
}
leaders: JMap[TopicAndPartition, Broker]
): JavaRDD[Array[Byte]] = {
val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
new PythonMessageAndMetadata(
mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())

val jrdd = KafkaUtils.createRDD[
KafkaUtils.createRDD[
Array[Byte],
Array[Byte],
DefaultDecoder,
DefaultDecoder,
(Array[Byte], Array[Byte])](
jsc,
classOf[Array[Byte]],
classOf[Array[Byte]],
classOf[DefaultDecoder],
classOf[DefaultDecoder],
classOf[(Array[Byte], Array[Byte])],
kafkaParams,
PythonMessageAndMetadata](
jsc.sc,
Map(kafkaParams.asScala.toSeq: _*),
offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
leaders,
messageHandler
)
new JavaPairRDD(jrdd.rdd)
Map(leaders.asScala.toSeq: _*),
messageHandler).mapPartitions { iter => picklerIterator(iter) }
}

def createDirectStream(
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
topics: JSet[String],
fromOffsets: JMap[TopicAndPartition, JLong]
): JavaPairInputDStream[Array[Byte], Array[Byte]] = {
): JavaDStream[Array[Byte]] = {

if (!fromOffsets.isEmpty) {
val currentFromOffsets = if (!fromOffsets.isEmpty) {
val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic)
if (topicsFromOffsets != topics.asScala.toSet) {
throw new IllegalStateException(
s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " +
s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}")
}
Map(fromOffsets.asScala.mapValues { _.longValue() }.toSeq: _*)
} else {
val kc = new KafkaCluster(Map(kafkaParams.asScala.toSeq: _*))
KafkaUtils.getFromOffsets(
kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*))
}

if (fromOffsets.isEmpty) {
KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder](
jssc,
classOf[Array[Byte]],
classOf[Array[Byte]],
classOf[DefaultDecoder],
classOf[DefaultDecoder],
kafkaParams,
topics)
} else {
val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
(Array[Byte], Array[Byte])] {
def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
(t1.key(), t1.message())
}
val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
new PythonMessageAndMetadata(
mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())

val jstream = KafkaUtils.createDirectStream[
Array[Byte],
Array[Byte],
DefaultDecoder,
DefaultDecoder,
(Array[Byte], Array[Byte])](
jssc,
classOf[Array[Byte]],
classOf[Array[Byte]],
classOf[DefaultDecoder],
classOf[DefaultDecoder],
classOf[(Array[Byte], Array[Byte])],
kafkaParams,
fromOffsets,
messageHandler)
new JavaPairInputDStream(jstream.inputDStream)
}
val stream = KafkaUtils.createDirectStream[
Array[Byte],
Array[Byte],
DefaultDecoder,
DefaultDecoder,
PythonMessageAndMetadata](
jssc.ssc,
Map(kafkaParams.asScala.toSeq: _*),
Map(currentFromOffsets.toSeq: _*),
messageHandler).mapPartitions { iter => picklerIterator(iter) }
new JavaDStream(stream)
}

def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong
@@ -669,3 +661,57 @@
kafkaRDD.offsetRanges.toSeq.asJava
}
}

private object KafkaUtilsPythonHelper {
private var initialized = false

def initialize(): Unit = {

Member:

Why not directly call SerDeUtil.initialize() and new PythonMessageAndMetadataPickler().register() in the static initialization code of KafkaUtilsPythonHelper? It would be much simpler, since it's only two lines.

Member:

Oh, I see. You want to call initialize so that the KafkaUtilsPythonHelper object gets loaded inside the closure. If so, how about adding a method to KafkaUtilsPythonHelper, such as

object KafkaUtilsPythonHelper {
  ...
  def createPicklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = {
    new SerDeUtil.AutoBatchedPickler(iter)
  }
  ...
}

and call it in the closure?

Contributor (author):

Could you please share more details? I don't quite get what you mean. What I did here follows the pattern in MLlib (https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala#L1370).

Member:

Since initialize() is already called inside the KafkaUtilsPythonHelper object, the only purpose of adding the initialize method is to force the class of the KafkaUtilsPythonHelper object to be loaded, right?

If it is changed to

object KafkaUtilsPythonHelper {

  SerDeUtil.initialize()
  new PythonMessageAndMetadataPickler().register()

  ...
  def createPicklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = {
    new SerDeUtil.AutoBatchedPickler(iter)
  }
  ...
}

then calling KafkaUtilsPythonHelper.createPicklerIterator in the closure will run SerDeUtil.initialize() and new PythonMessageAndMetadataPickler().register() as part of loading the class of the KafkaUtilsPythonHelper object.
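
A minimal, self-contained sketch of the class-loading behavior described above; every name in it (PicklerRegistry, pickleAll, LoadOnReferenceDemo) is an illustrative stand-in, not code from this PR:

// Illustrative sketch only: a Scala object's body runs once per JVM, the first
// time the object is referenced, so a closure that calls a helper method on the
// object triggers the one-time registration on whichever executor runs it.
object PicklerRegistry {
  // Stand-in for SerDeUtil.initialize() and PythonMessageAndMetadataPickler().register()
  println("registering custom picklers in this JVM")

  def pickleAll(iter: Iterator[Any]): Iterator[String] =
    iter.map(record => s"pickled($record)") // stand-in for SerDeUtil.AutoBatchedPickler
}

object LoadOnReferenceDemo {
  def main(args: Array[String]): Unit = {
    // The first reference to PicklerRegistry loads it and prints the registration
    // message, mirroring how calling picklerIterator inside mapPartitions loads
    // KafkaUtilsPythonHelper on the executor before any record is pickled.
    PicklerRegistry.pickleAll(Iterator("a", "b", "c")).foreach(println)
  }
}

The committed code that follows keeps the explicit initialize() method guarded by a synchronized initialized flag rather than relying on the object body alone.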

SerDeUtil.initialize()
synchronized {
if (!initialized) {
new PythonMessageAndMetadataPickler().register()
initialized = true
}
}
}

initialize()

def picklerIterator(iter: Iterator[Any]): Iterator[Array[Byte]] = {
new SerDeUtil.AutoBatchedPickler(iter)
}

case class PythonMessageAndMetadata(
topic: String,
partition: JInt,
offset: JLong,
key: Array[Byte],
message: Array[Byte])

class PythonMessageAndMetadataPickler extends IObjectPickler {
private val module = "pyspark.streaming.kafka"

def register(): Unit = {
Pickler.registerCustomPickler(classOf[PythonMessageAndMetadata], this)
Pickler.registerCustomPickler(this.getClass, this)
}

def pickle(obj: Object, out: OutputStream, pickler: Pickler) {
if (obj == this) {
out.write(Opcodes.GLOBAL)
out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes(UTF_8))
} else {
pickler.save(this)
val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata]
out.write(Opcodes.MARK)
pickler.save(msgAndMetaData.topic)
pickler.save(msgAndMetaData.partition)
pickler.save(msgAndMetaData.offset)
pickler.save(msgAndMetaData.key)
pickler.save(msgAndMetaData.message)
out.write(Opcodes.TUPLE)
out.write(Opcodes.REDUCE)
}
}
}
}
6 changes: 6 additions & 0 deletions project/MimaExcludes.scala
@@ -134,6 +134,12 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$")
) ++ Seq(
// SPARK-9065 Support message handler in Kafka Python API
ProblemFilters.exclude[IncompatibleResultTypeProblem](
"org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"),
ProblemFilters.exclude[IncompatibleResultTypeProblem](
"org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD")
)
case v if v.startsWith("1.5") =>
Seq(