[SPARK-23539][SS] Add support for Kafka headers in Structured Streaming #22282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions docs/structured-streaming-kafka-integration.md
@@ -27,6 +27,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your application
artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}
version = {{site.SPARK_VERSION_SHORT}}

Please note that the headers functionality requires Kafka client version 0.11.0.0 or newer.

For Python applications, you need to add the above library and its dependencies when deploying your
application. See the [Deploying](#deploying) subsection below.

@@ -50,6 +52,17 @@ val df = spark
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]

// Subscribe to 1 topic, with headers
val df = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", "host1:port1,host2:port2")
.option("subscribe", "topic1")
.option("includeHeaders", "true")
.load()
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
.as[(String, String, Array[(String, Array[Byte])])]

// Subscribe to multiple topics
val df = spark
.readStream
@@ -84,6 +97,16 @@ Dataset<Row> df = spark
.load();
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");

// Subscribe to 1 topic, with headers
Dataset<Row> df = spark
.readStream()
.format("kafka")
.option("kafka.bootstrap.servers", "host1:port1,host2:port2")
.option("subscribe", "topic1")
.option("includeHeaders", "true")
.load();
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers");

// Subscribe to multiple topics
Dataset<Row> df = spark
.readStream()
@@ -116,6 +139,16 @@ df = spark \
.load()
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

# Subscribe to 1 topic, with headers
df = spark \
.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
.option("subscribe", "topic1") \
.option("includeHeaders", "true") \
.load()
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")

# Subscribe to multiple topics
df = spark \
.readStream \
@@ -286,6 +319,10 @@ Each row in the source has the following schema:
<td>timestampType</td>
<td>int</td>
</tr>
<tr>
<td>headers (optional)</td>
<td>array</td>
</tr>
</table>

The following options must be set for the Kafka source
@@ -425,6 +462,13 @@ The following configurations are optional:
issues, set the Kafka consumer session timeout (by setting option "kafka.session.timeout.ms") to
be very small. When this is set, option "groupIdPrefix" will be ignored.</td>
</tr>
<tr>
<td>includeHeaders</td>
<td>boolean</td>
<td>false</td>
<td>streaming and batch</td>
<td>Whether to include the Kafka headers in the row.</td>
</tr>
</table>
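
Since headers is exposed as an array of (key, value) structs and Kafka allows duplicate header keys, a common follow-up step is extracting a single header value. The following is a minimal, illustrative Scala sketch, not part of this change; the header key `traceId` and the session setup are assumptions:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

val spark = SparkSession.builder().appName("kafka-headers-read-sketch").getOrCreate()

// Read from Kafka with headers included, as in the examples above.
val df = spark
  .readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("subscribe", "topic1")
  .option("includeHeaders", "true")
  .load()

// "headers" is an array<struct<key: string, value: binary>>. Kafka permits duplicate
// keys, so filter first, take the first match, and decode the binary value as UTF-8.
val withTraceId = df
  .withColumn("matching", expr("filter(headers, h -> h.key = 'traceId')"))
  .withColumn("traceId", expr("CAST(element_at(matching, 1).value AS STRING)"))
  .drop("matching")
```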

### Consumer Caching
@@ -522,6 +566,10 @@ The Dataframe being written to Kafka should have the following columns in schema
<td>value (required)</td>
<td>string or binary</td>
</tr>
<tr>
<td>headers (optional)</td>
<td>array</td>
</tr>
<tr>
<td>topic (*optional)</td>
<td>string</td>
@@ -559,6 +607,13 @@ The following configurations are optional:
<td>Sets the topic that all rows will be written to in Kafka. This option overrides any
topic column that may exist in the data.</td>
</tr>
<tr>
<td>includeHeaders</td>
<td>boolean</td>
<td>false</td>
<td>streaming and batch</td>
<td>Whether to include the Kafka headers in the row.</td>
</tr>
</table>
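
On the write path the headers column has the same shape: an array of structs with a string key and a binary value. Below is a minimal Scala sketch of attaching a header before writing, assuming a DataFrame `df` that already carries key and value columns; the header contents and checkpoint path are placeholders:

```scala
import org.apache.spark.sql.functions.{array, lit, struct}

// Attach a single header ("source" -> "structured-streaming") to every record.
// Header values must be binary, hence the cast; repeated keys would simply be
// additional entries in the array.
val toKafka = df
  .selectExpr("CAST(key AS STRING) AS key", "CAST(value AS STRING) AS value")
  .withColumn("headers",
    array(struct(
      lit("source").as("key"),
      lit("structured-streaming").cast("binary").as("value"))))

toKafka
  .writeStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("topic", "topic1")
  .option("includeHeaders", "true")                 // as listed in the options table above
  .option("checkpointLocation", "/tmp/checkpoint")  // placeholder path
  .start()
```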

### Creating a Kafka Sink for Streaming Queries
@@ -31,7 +31,8 @@ private[kafka010] class KafkaBatch(
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
endingOffsets: KafkaOffsetRangeLimit)
endingOffsets: KafkaOffsetRangeLimit,
includeHeaders: Boolean)
extends Batch with Logging {
assert(startingOffsets != LatestOffsetRangeLimit,
"Starting offset not allowed to be set to latest offsets.")
@@ -90,7 +91,7 @@ private[kafka010] class KafkaBatch(
KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
offsetRanges.map { range =>
new KafkaBatchInputPartition(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
}.toArray
}

@@ -29,13 +29,14 @@ private[kafka010] case class KafkaBatchInputPartition(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends InputPartition
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends InputPartition

private[kafka010] object KafkaBatchReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val p = partition.asInstanceOf[KafkaBatchInputPartition]
KafkaBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs,
p.failOnDataLoss)
p.failOnDataLoss, p.includeHeaders)
}
}

@@ -44,12 +45,14 @@ private case class KafkaBatchPartitionReader(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends PartitionReader[InternalRow] with Logging {
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging {

private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams)

private val rangeToRead = resolveRange(offsetRange)
private val converter = new KafkaRecordToUnsafeRowConverter
private val unsafeRowProjector = new KafkaRecordToRowConverter()
.toUnsafeRowProjector(includeHeaders)

private var nextOffset = rangeToRead.fromOffset
private var nextRow: UnsafeRow = _
@@ -58,7 +61,7 @@
if (nextOffset < rangeToRead.untilOffset) {
val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss)
if (record != null) {
nextRow = converter.toUnsafeRow(record)
nextRow = unsafeRowProjector(record)
nextOffset = record.offset + 1
true
} else {
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset}
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
@@ -56,6 +56,7 @@ class KafkaContinuousStream(

private[kafka010] val pollTimeoutMs =
options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)

// Initialized when creating reader factories. If this diverges from the partitions at the latest
// offsets, we need to reconfigure.
@@ -88,7 +89,7 @@
if (deletedPartitions.nonEmpty) {
val message = if (
offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
@@ -102,7 +103,7 @@
startOffsets.toSeq.map {
case (topicPartition, start) =>
KafkaContinuousInputPartition(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
}.toArray
}

@@ -153,19 +154,22 @@
* @param pollTimeoutMs The timeout for Kafka consumer polling.
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
* @param includeHeaders Flag indicating whether to include Kafka records' headers.
*/
case class KafkaContinuousInputPartition(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends InputPartition
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends InputPartition

object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
val p = partition.asInstanceOf[KafkaContinuousInputPartition]
new KafkaContinuousPartitionReader(
p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss)
p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs,
p.failOnDataLoss, p.includeHeaders)
}
}

@@ -184,9 +188,11 @@ class KafkaContinuousPartitionReader(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] {
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends ContinuousPartitionReader[InternalRow] {
private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
private val converter = new KafkaRecordToUnsafeRowConverter

Contributor: I'd have KafkaRecordToRowProjector (either a class or an object, but an object would be fine) instead, and move every projector newly added in KafkaOffsetReader there.

Contributor: So +1 on your proposal. The proposed name is just my two cents; I'm not sure which name fits best.

Contributor (Author): @HeartSaVioR Great. Let's continue discussing which name would be best.

private val unsafeRowProjector = new KafkaRecordToRowConverter()
.toUnsafeRowProjector(includeHeaders)

private var nextKafkaOffset = startOffset
private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
@@ -225,7 +231,7 @@
}

override def get(): UnsafeRow = {
converter.toUnsafeRow(currentRecord)
unsafeRowProjector(currentRecord)
}

override def getOffset(): KafkaSourcePartitionOffset = {
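
The review thread above concerns where the record-to-row projection should live. The KafkaRecordToRowConverter referenced by the new call sites is not part of this diff; the sketch below only illustrates the shape implied by `new KafkaRecordToRowConverter().toUnsafeRowProjector(includeHeaders)` — the names, fields, and details are assumptions, not the actual implementation:

```scala
import org.apache.kafka.clients.consumer.ConsumerRecord

import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Illustrative sketch only: one place that owns the Kafka record schema (with and
// without headers) and hands out a record -> UnsafeRow projector.
class KafkaRecordToRowConverter {
  import KafkaRecordToRowConverter._

  def toUnsafeRowProjector(
      includeHeaders: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] => UnsafeRow = {
    val schema = if (includeHeaders) schemaWithHeaders else schemaWithoutHeaders
    val toUnsafe = UnsafeProjection.create(schema)
    record => {
      val base = Array[Any](
        record.key, record.value,
        UTF8String.fromString(record.topic), record.partition, record.offset,
        DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp)),
        record.timestampType.id)
      val fields = if (includeHeaders) {
        // Headers stay an array (not a map) because Kafka allows duplicate keys.
        val headerRows = record.headers.toArray.toSeq.map { h =>
          new GenericInternalRow(Array[Any](UTF8String.fromString(h.key), h.value))
        }
        base :+ new GenericArrayData(headerRows)
      } else {
        base
      }
      // The projection reuses an internal buffer, so copy before handing the row out.
      toUnsafe(new GenericInternalRow(fields)).copy()
    }
  }
}

object KafkaRecordToRowConverter {
  val schemaWithoutHeaders: StructType = StructType(Seq(
    StructField("key", BinaryType),
    StructField("value", BinaryType),
    StructField("topic", StringType),
    StructField("partition", IntegerType),
    StructField("offset", LongType),
    StructField("timestamp", TimestampType),
    StructField("timestampType", IntegerType)))

  val schemaWithHeaders: StructType = StructType(schemaWithoutHeaders.fields :+
    StructField("headers", ArrayType(StructType(Seq(
      StructField("key", StringType),
      StructField("value", BinaryType))))))
}
```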
@@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.UninterruptibleThread

@@ -64,6 +64,8 @@ private[kafka010] class KafkaMicroBatchStream(
private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)

private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)

private val rangeCalculator = KafkaOffsetRangeCalculator(options)

private var endPartitionOffsets: KafkaSourceOffset = _
@@ -112,7 +114,7 @@ private[kafka010] class KafkaMicroBatchStream(
if (deletedPartitions.nonEmpty) {
val message =
if (kafkaOffsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
@@ -146,7 +148,8 @@

// Generate factories based on the offset ranges
offsetRanges.map { range =>
KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs,
failOnDataLoss, includeHeaders)
}.toArray
}

@@ -31,7 +31,6 @@ import org.apache.kafka.common.TopicPartition

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}

/**
@@ -421,16 +420,3 @@ private[kafka010] class KafkaOffsetReader(
_consumer = null // will automatically get reinitialized again
}
}

private[kafka010] object KafkaOffsetReader {

def kafkaSchema: StructType = StructType(Seq(
StructField("key", BinaryType),
StructField("value", BinaryType),
StructField("topic", StringType),
StructField("partition", IntegerType),
StructField("offset", LongType),
StructField("timestamp", TimestampType),
StructField("timestampType", IntegerType)
))
}