Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
package tech.figure.kafka.coroutines.channels

import tech.figure.kafka.loggingConsumerRebalanceListener
import tech.figure.kafka.records.CommitConsumerRecord
import tech.figure.kafka.records.UnAckedConsumerRecordImpl
import tech.figure.kafka.records.UnAckedConsumerRecords
import java.util.concurrent.atomic.AtomicInteger
import kotlin.concurrent.thread
import kotlin.time.Duration
import kotlin.time.ExperimentalTime
import kotlin.time.toJavaDuration
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.ExperimentalCoroutinesApi
Expand All @@ -21,7 +16,73 @@ import kotlinx.coroutines.selects.SelectClause1
import mu.KotlinLogging
import org.apache.kafka.clients.consumer.Consumer
import org.apache.kafka.clients.consumer.ConsumerRebalanceListener
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.clients.consumer.ConsumerRecords
import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import tech.figure.kafka.loggingConsumerRebalanceListener
import tech.figure.kafka.records.CommitConsumerRecord
import tech.figure.kafka.records.UnAckedConsumerRecordImpl
import tech.figure.kafka.records.UnAckedConsumerRecords

/**
 * Bundle this list of records into a single [ConsumerRecords] instance, with the records keyed by
 * the [TopicPartition] built from each record's topic and partition.
 */
internal fun <K, V> List<ConsumerRecord<K, V>>.toConsumerRecords(): ConsumerRecords<K, V> {
    val byPartition = groupBy { record -> TopicPartition(record.topic(), record.partition()) }
    return ConsumerRecords(byPartition)
}

/**
 * Create the default kafka consumer channel: a committable channel that emits unacknowledged
 * records. Every argument is forwarded unchanged to [kafkaAckConsumerChannel].
 *
 * @see [kafkaAckConsumerChannel]
 */
fun <K, V> kafkaConsumerChannel(
    consumerProperties: Map<String, Any>,
    topics: Set<String>,
    name: String = "kafka-channel",
    pollInterval: Duration = DEFAULT_POLL_INTERVAL,
    consumer: Consumer<K, V> = KafkaConsumer(consumerProperties),
    rebalanceListener: ConsumerRebalanceListener = loggingConsumerRebalanceListener(),
    init: Consumer<K, V>.() -> Unit = { subscribe(topics, rebalanceListener) },
): ReceiveChannel<UnAckedConsumerRecords<K, V>> {
    return kafkaAckConsumerChannel(
        consumerProperties,
        topics,
        name,
        pollInterval,
        consumer,
        rebalanceListener,
        init,
    )
}

/**
 * Create a [ReceiveChannel] for [ConsumerRecords] from kafka.
 *
 * The returned channel emits each record set exactly as handed to it by the polling loop; no
 * acknowledgement wrapper is applied to the records.
 *
 * @param consumerProperties Kafka consumer settings for this channel.
 * @param topics Topics to subscribe to. Can be overridden via custom `init` parameter.
 * @param name The thread pool's base name for this consumer.
 * @param pollInterval Interval for kafka consumer [Consumer.poll] method calls.
 * @param consumer The instantiated [Consumer] to use to receive from kafka.
 * @param rebalanceListener Listener handed to [Consumer.subscribe] by the default `init` callback.
 * @param init Callback for initializing the [Consumer].
 * @return A non-running [KafkaConsumerChannel] instance that must be started via
 * [KafkaConsumerChannel.start].
 */
fun <K, V> kafkaNoAckConsumerChannel(
    consumerProperties: Map<String, Any>,
    topics: Set<String>,
    name: String = "kafka-channel",
    pollInterval: Duration = DEFAULT_POLL_INTERVAL,
    consumer: Consumer<K, V> = KafkaConsumer(consumerProperties),
    rebalanceListener: ConsumerRebalanceListener = loggingConsumerRebalanceListener(),
    init: Consumer<K, V>.() -> Unit = { subscribe(topics, rebalanceListener) },
): ReceiveChannel<ConsumerRecords<K, V>> =
    object : KafkaConsumerChannel<K, V, ConsumerRecords<K, V>>(
        consumerProperties = consumerProperties,
        topics = topics,
        name = name,
        pollInterval = pollInterval,
        consumer = consumer,
        init = init,
    ) {
        // Pass-through: the record set is emitted downstream untouched, as a single element.
        override suspend fun preProcessPollSet(
            records: ConsumerRecords<K, V>,
            context: MutableMap<String, Any>
        ): List<ConsumerRecords<K, V>> = listOf(records)
    }

/**
* Create a [ReceiveChannel] for unacknowledged consumer records from kafka.
Expand All @@ -32,9 +93,10 @@ import org.apache.kafka.clients.consumer.KafkaConsumer
* @param pollInterval Interval for kafka consumer [Consumer.poll] method calls.
* @param consumer The instantiated [Consumer] to use to receive from kafka.
* @param init Callback for initializing the [Consumer].
* @return A non-running [KafkaConsumerChannel] instance that must be started via
* [KafkaConsumerChannel.start].
*/
fun <K, V> kafkaConsumerChannel(
fun <K, V> kafkaAckConsumerChannel(
consumerProperties: Map<String, Any>,
topics: Set<String>,
name: String = "kafka-channel",
Expand All @@ -43,20 +105,95 @@ fun <K, V> kafkaConsumerChannel(
rebalanceListener: ConsumerRebalanceListener = loggingConsumerRebalanceListener(),
init: Consumer<K, V>.() -> Unit = { subscribe(topics, rebalanceListener) },
): ReceiveChannel<UnAckedConsumerRecords<K, V>> {
return KafkaConsumerChannel(consumerProperties, topics, name, pollInterval, consumer, init).also {
Runtime.getRuntime().addShutdownHook(
Thread {
it.cancel()
return KafkaAckConsumerChannel(
consumerProperties,
topics,
name,
pollInterval,
consumer,
init
).also { Runtime.getRuntime().addShutdownHook(Thread { it.cancel() }) }
}

/**
 * Acking kafka [Consumer] object implementing the [ReceiveChannel] methods.
 *
 * Each polled record is wrapped in an [UnAckedConsumerRecordImpl] bound to an ack channel. Once
 * the wrapped records have been handed downstream, [postProcessPollSet] receives one
 * [CommitConsumerRecord] per record from that channel, commits it via [commit], and then signals
 * the acknowledger through the record's `commitAck` channel.
 *
 * Note: Must operate in a bound thread context regardless of coroutine assignment due to internal
 * kafka threading limitations for poll fetches, acknowledgements, and sends.
 *
 * @param consumerProperties Kafka consumer settings for this channel.
 * @param topics Topics to subscribe to. Can be overridden via custom `init` parameter.
 * @param name The thread pool's base name for this consumer.
 * @param pollInterval Interval for kafka consumer [Consumer.poll] method calls.
 * @param consumer The instantiated [Consumer] to use to receive from kafka.
 * @param init Callback for initializing the [Consumer].
 */
internal class KafkaAckConsumerChannel<K, V>(
    consumerProperties: Map<String, Any>,
    topics: Set<String>,
    name: String,
    pollInterval: Duration,
    consumer: Consumer<K, V>,
    init: Consumer<K, V>.() -> Unit
) :
    KafkaConsumerChannel<K, V, UnAckedConsumerRecords<K, V>>(
        consumerProperties,
        topics,
        name,
        pollInterval,
        consumer,
        init
    ) {
    /**
     * Wrap [records] into un-acked record sets, one set per topic-partition.
     *
     * The polling loop in this file invokes this method once per topic-partition while passing
     * the same [context] to every invocation. Unconditionally storing a fresh channel under one
     * key would therefore orphan the channel of every earlier partition, so the ack channel is
     * shared for as long as acks from earlier invocations are still outstanding.
     *
     * @param records Polled records to wrap.
     * @param context Per-cycle scratch space; receives the ack channel and the outstanding-ack
     * counter that [postProcessPollSet] reads back.
     * @return The wrapped records grouped by topic-partition.
     */
    @Suppress("unchecked_cast")
    override suspend fun preProcessPollSet(
        records: ConsumerRecords<K, V>,
        context: MutableMap<String, Any>,
    ): List<UnAckedConsumerRecords<K, V>> {
        log.trace { "preProcessPollSet(${records.count()})" }

        val pending = context.getOrPut(ACK_PENDING_KEY) { AtomicInteger(0) } as AtomicInteger
        // Reuse the channel only while acks are outstanding. A zero counter means either nothing
        // was registered yet or the previous channel has already been drained and closed.
        val shared =
            (context[ACK_CHANNEL_KEY] as Channel<CommitConsumerRecord>?)?.takeIf {
                pending.get() > 0
            }
        // UNLIMITED: the total number of acks is not known when the first partition is wrapped,
        // and an acknowledger must never block on sending its ack.
        val ackChannel =
            shared
                ?: Channel<CommitConsumerRecord>(Channel.UNLIMITED).also {
                    context[ACK_CHANNEL_KEY] = it
                }
        pending.addAndGet(records.count())

        return records
            .groupBy { "${it.topic()}-${it.partition()}" }
            .map { (_, partitionRecords) ->
                val timestamp = System.currentTimeMillis()
                UnAckedConsumerRecords(
                    partitionRecords.map { UnAckedConsumerRecordImpl(it, ackChannel, timestamp) }
                )
            }
    }

    /**
     * Wait for one ack per record in [records], committing each ack as it arrives and notifying
     * the acknowledger afterwards. The ack channel is closed once no acks remain outstanding for
     * the whole cycle.
     *
     * @param records The record sets previously returned by [preProcessPollSet].
     * @param context The context populated by [preProcessPollSet].
     * @throws IllegalStateException if [preProcessPollSet] did not populate [context].
     */
    @Suppress("unchecked_cast")
    override suspend fun postProcessPollSet(
        records: List<UnAckedConsumerRecords<K, V>>,
        context: Map<String, Any>
    ) {
        log.trace { "postProcessPollSet(records:${records.sumOf { it.count() } })" }
        val ackChannel =
            checkNotNull(context[ACK_CHANNEL_KEY]) {
                "no ack channel in context; preProcessPollSet must run before postProcessPollSet"
            } as Channel<CommitConsumerRecord>
        val pending =
            checkNotNull(context[ACK_PENDING_KEY]) {
                "no pending ack counter in context; preProcessPollSet must run before postProcessPollSet"
            } as AtomicInteger

        for (rs in records) {
            // Count down so the trace shows how many acks are still expected for this set.
            for (remaining in rs.records.size - 1 downTo 0) {
                log.trace { "waiting for $remaining commits" }
                val ack = ackChannel.receive()
                log.trace {
                    "sending to broker ack(${ack.duration.toMillis()}ms):${ack.asCommitable()}"
                }
                commit(ack)
                log.trace { "acking the commit back to flow" }
                ack.commitAck.send(Unit)
                pending.decrementAndGet()
            }
        }

        // Close only when every record registered for this cycle has been acked; sets from other
        // partitions may still be waiting on the same channel.
        if (pending.get() <= 0) {
            ackChannel.close()
        }
    }

    private companion object {
        /** Context key holding the cycle's shared ack channel. */
        const val ACK_CHANNEL_KEY = "ack-channel"

        /** Context key holding the number of acks still outstanding for the cycle. */
        const val ACK_PENDING_KEY = "ack-pending"
    }
}

/**
* Base kafka [Consumer] object implementing the [ReceiveChannel] methods.
*
* Note: Must operate in a bound thread context regardless of coroutine assignment due to internal
* kafka threading limitations for poll fetches, acknowledgements, and sends.
*
* @param consumerProperties Kafka consumer settings for this channel.
* @param topics Topics to subscribe to. Can be overridden via custom `init` parameter.
Expand All @@ -65,40 +202,46 @@ fun <K, V> kafkaConsumerChannel(
* @param consumer The instantiated [Consumer] to use to receive from kafka.
* @param init Callback for initializing the [Consumer].
*/
open class KafkaConsumerChannel<K, V>(
abstract class KafkaConsumerChannel<K, V, R>(
consumerProperties: Map<String, Any>,
topics: Set<String> = emptySet(),
name: String = "kafka-channel",
private val pollInterval: Duration = DEFAULT_POLL_INTERVAL,
private val consumer: Consumer<K, V> = KafkaConsumer(consumerProperties),
private val init: Consumer<K, V>.() -> Unit = { subscribe(topics) },
) : ReceiveChannel<UnAckedConsumerRecords<K, V>> {
) : ReceiveChannel<R> {
companion object {
private val threadCounter = AtomicInteger(0)
}

private val log = KotlinLogging.logger {}
protected val log = KotlinLogging.logger {}
private val thread =
thread(name = "$name-${threadCounter.getAndIncrement()}", block = { run() }, isDaemon = true, start = false)
private val sendChannel = Channel<UnAckedConsumerRecords<K, V>>(Channel.UNLIMITED)

private inline fun <T> Channel<T>.use(block: (Channel<T>) -> Unit) {
try {
block(this)
close()
} catch (e: Throwable) {
close(e)
}
}
thread(
name = "$name-${threadCounter.getAndIncrement()}",
block = { run() },
isDaemon = true,
start = false
)
val sendChannel = Channel<R>(Channel.UNLIMITED)

@OptIn(ExperimentalTime::class)
private fun <K, V> Consumer<K, V>.poll(duration: Duration) =
poll(duration.toJavaDuration())
private fun <K, V> Consumer<K, V>.poll(duration: Duration) = poll(duration.toJavaDuration())

private fun <T, L : Iterable<T>> L.ifEmpty(block: () -> L): L =
if (count() == 0) block() else this

@OptIn(ExperimentalCoroutinesApi::class, ExperimentalTime::class)
protected abstract suspend fun preProcessPollSet(
records: ConsumerRecords<K, V>,
context: MutableMap<String, Any>
): List<R>

protected open suspend fun postProcessPollSet(records: List<R>, context: Map<String, Any>) {}

protected fun commit(record: CommitConsumerRecord): OffsetAndMetadata {
consumer.commitSync(record.asCommitable())
return record.offsetAndMetadata
}

@OptIn(ExperimentalCoroutinesApi::class)
fun run() {
consumer.init()

Expand All @@ -108,32 +251,31 @@ open class KafkaConsumerChannel<K, V>(
try {
while (!sendChannel.isClosedForSend) {
log.trace("poll(topics:${consumer.subscription()}) ...")
val polled = consumer.poll(Duration.ZERO).ifEmpty { consumer.poll(pollInterval) }
val polled =
consumer.poll(Duration.ZERO).ifEmpty { consumer.poll(pollInterval) }
val polledCount = polled.count()
if (polledCount == 0) {
continue
}

log.trace("poll(topics:${consumer.subscription()}) got $polledCount records.")
Channel<CommitConsumerRecord>(capacity = polled.count()).use { ackChannel ->
for (it in polled.groupBy { "${it.topic()}-${it.partition()}" }) {
val timestamp = System.currentTimeMillis()
val records = it.value.map {
UnAckedConsumerRecordImpl(it, ackChannel, timestamp)
}
sendChannel.send(UnAckedConsumerRecords(records))
}

if (polledCount > 0) {
val count = AtomicInteger(polledCount)
while (count.getAndDecrement() > 0) {
val it = ackChannel.receive()
log.debug { "ack(${it.duration.toMillis()}ms):${it.asCommitable()}" }
consumer.commitSync(it.asCommitable())
it.commitAck.send(Unit)
}
}
}

// Group by topic-partition to guarantee ordering.
val records =
polled
.groupBy { "${it.topic()}-${it.partition()}" }
.values
.map { it.toConsumerRecords() }

// Convert to internal types.
val context = mutableMapOf<String, Any>()
val processSet = records.map { preProcessPollSet(it, context) }

// Send down the pipeline for processing
processSet
.onEach { it.map { sendChannel.send(it) } }
// Clean up any processing.
.map { postProcessPollSet(it, context) }
}
} finally {
log.info("${coroutineContext.job} shutting down consumer thread")
Expand All @@ -142,7 +284,9 @@ open class KafkaConsumerChannel<K, V>(
consumer.unsubscribe()
consumer.close()
} catch (ex: Exception) {
log.debug { "Consumer failed to be closed. It may have been closed from somewhere else." }
log.debug {
"Consumer failed to be closed. It may have been closed from somewhere else."
}
}
}
}
Expand All @@ -162,19 +306,23 @@ open class KafkaConsumerChannel<K, V>(
@ExperimentalCoroutinesApi
override val isClosedForReceive: Boolean = sendChannel.isClosedForReceive

@ExperimentalCoroutinesApi
override val isEmpty: Boolean = sendChannel.isEmpty
override val onReceive: SelectClause1<UnAckedConsumerRecords<K, V>> get() {
start()
return sendChannel.onReceive
}
@ExperimentalCoroutinesApi override val isEmpty: Boolean = sendChannel.isEmpty
override val onReceive: SelectClause1<R>
get() {
start()
return sendChannel.onReceive
}

override val onReceiveCatching: SelectClause1<ChannelResult<UnAckedConsumerRecords<K, V>>> get() {
start()
return sendChannel.onReceiveCatching
}
override val onReceiveCatching: SelectClause1<ChannelResult<R>>
get() {
start()
return sendChannel.onReceiveCatching
}

@Deprecated("Since 1.2.0, binary compatibility with versions <= 1.1.x", level = DeprecationLevel.HIDDEN)
@Deprecated(
"Since 1.2.0, binary compatibility with versions <= 1.1.x",
level = DeprecationLevel.HIDDEN
)
override fun cancel(cause: Throwable?): Boolean {
cancel(CancellationException("cancel", cause))
return true
Expand All @@ -185,22 +333,22 @@ open class KafkaConsumerChannel<K, V>(
sendChannel.cancel(cause)
}

override fun iterator(): ChannelIterator<UnAckedConsumerRecords<K, V>> {
override fun iterator(): ChannelIterator<R> {
start()
return sendChannel.iterator()
}

override suspend fun receive(): UnAckedConsumerRecords<K, V> {
override suspend fun receive(): R {
start()
return sendChannel.receive()
}

override suspend fun receiveCatching(): ChannelResult<UnAckedConsumerRecords<K, V>> {
override suspend fun receiveCatching(): ChannelResult<R> {
start()
return sendChannel.receiveCatching()
}

override fun tryReceive(): ChannelResult<UnAckedConsumerRecords<K, V>> {
override fun tryReceive(): ChannelResult<R> {
start()
return sendChannel.tryReceive()
}
Expand Down