Re-attempt partitionedLoad #140

Draft · wants to merge 13 commits into base: master
1 change: 1 addition & 0 deletions build.sbt
@@ -25,6 +25,7 @@ semanticdbEnabled := true
semanticdbVersion := scalafixSemanticdb.revision

tpolecatScalacOptions ++= Set(ScalacOptions.source3)
tpolecatScalacOptions ~= (_.filterNot(Set(ScalacOptions.warnValueDiscard)))

ThisBuild / scalacOptions ++= Seq("-explaintypes", "-Wconf:msg=annotation:silent")

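For context on the new filter: warnValueDiscard corresponds to -Wvalue-discard, which flags any non-Unit expression whose value is silently dropped, a pattern the new stream wiring presumably trips. A minimal standalone illustration, not from this repo:

    // Under -Wvalue-discard this warns: the Int returned by compute() is
    // thrown away in a Unit position.
    def compute(): Int = 42

    def run(): Unit =
      compute() // warning: discarded non-Unit value of type Int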
113 changes: 85 additions & 28 deletions src/main/scala/uk/sky/kafka/topicloader/TopicLoader.scala
@@ -1,7 +1,7 @@
package uk.sky.kafka.topicloader

import java.lang.{Long => JLong}
import java.util.{List => JList, Map => JMap, Optional}
import java.lang.Long as JLong
import java.util.{List as JList, Map as JMap, Optional}

import akka.Done
import akka.actor.ActorSystem
@@ -10,18 +10,19 @@ import akka.kafka.{ConsumerSettings, Subscriptions}
import akka.stream.OverflowStrategy
import akka.stream.scaladsl.{Flow, Keep, Source}
import cats.data.NonEmptyList
import cats.syntax.bifunctor._
import cats.syntax.option._
import cats.syntax.show._
import cats.syntax.bifunctor.*
import cats.syntax.option.*
import cats.syntax.show.*
import cats.{Bifunctor, Show}
import com.typesafe.scalalogging.LazyLogging
import org.apache.kafka.clients.consumer._
import org.apache.kafka.clients.consumer.*
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization._
import org.apache.kafka.common.serialization.*
import uk.sky.kafka.topicloader.config.{Config, TopicLoaderConfig}

import scala.concurrent.Future
import scala.jdk.CollectionConverters._
import scala.jdk.CollectionConverters.*
import scala.util.Using

object TopicLoader extends TopicLoader {
private[topicloader] case class LogOffsets(lowest: Long, highest: Long)
@@ -76,30 +77,74 @@ trait TopicLoader extends LazyLogging {
strategy: LoadTopicStrategy,
maybeConsumerSettings: Option[ConsumerSettings[Array[Byte], Array[Byte]]] = None
)(implicit system: ActorSystem): Source[ConsumerRecord[K, V], Future[Consumer.Control]] = {
val config =
Config
.loadOrThrow(system.settings.config)
.topicLoader
val config = Config.loadOrThrow(system.settings.config).topicLoader
load(logOffsetsForTopics(topics, strategy, config), config, maybeConsumerSettings)
}

def partitionedLoad[K : Deserializer, V : Deserializer](
topics: NonEmptyList[String],
strategy: LoadTopicStrategy,
maybeConsumerSettings: Option[ConsumerSettings[Array[Byte], Array[Byte]]] = None
)(implicit
system: ActorSystem
): Source[(TopicPartition, Source[ConsumerRecord[K, V], Future[Consumer.Control]]), Consumer.Control] = {
val config = Config.loadOrThrow(system.settings.config).topicLoader
Consumer
.plainPartitionedSource(
consumerSettings(maybeConsumerSettings, config),
Subscriptions.topics(topics.toList.toSet)
)
.buffer(config.bufferSize.value, OverflowStrategy.backpressure)
.idleTimeout(config.idleTimeout)
.map { case (partition, _) =>
(
partition,
load[K, V](
logOffsetsForPartitions(NonEmptyList.one(partition), strategy, config),
config,
maybeConsumerSettings
)
)
}
}
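A minimal caller sketch for the new partitionedLoad (the topic name, deserializer wiring, and the flatMapMerge flattening are illustrative assumptions, not part of this PR):

    import akka.actor.ActorSystem
    import akka.stream.scaladsl.Sink
    import cats.data.NonEmptyList
    import org.apache.kafka.common.serialization.{Deserializer, StringDeserializer}
    import uk.sky.kafka.topicloader.*

    implicit val system: ActorSystem             = ActorSystem("partitioned-load-sketch")
    implicit val stringDes: Deserializer[String] = new StringDeserializer

    // Each element pairs a TopicPartition with a source that loads only that
    // partition and completes once its highest offset is reached.
    TopicLoader
      .partitionedLoad[String, String](NonEmptyList.one("my-topic"), LoadAll)
      .flatMapMerge(breadth = 8, { case (_, perPartition) => perPartition })
      .runWith(Sink.foreach(r => println(s"${r.key} -> ${r.value}")))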

def partitionedLoadAndRun[K : Deserializer, V : Deserializer](
topics: NonEmptyList[String],
maybeConsumerSettings: Option[ConsumerSettings[Array[Byte], Array[Byte]]] = None
)(implicit
system: ActorSystem
): Source[
(TopicPartition, Source[ConsumerRecord[K, V], (Future[Done], Future[Consumer.Control])]),
Consumer.Control
] = {
val config = Config.loadOrThrow(system.settings.config).topicLoader

Consumer
.plainPartitionedSource(
consumerSettings(maybeConsumerSettings, config),
Subscriptions.topics(topics.toList.toSet)
)
.map { case (partition, _) =>
(
partition,
loadAndRun(
logOffsetsForPartitions(NonEmptyList.one(partition), LoadAll, config),
config,
maybeConsumerSettings
)
)
}
}
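Similarly for partitionedLoadAndRun, a sketch of consuming each partition's source together with its load-complete callback (Sink.ignore, the topic name, and the println are placeholders):

    import akka.actor.ActorSystem
    import akka.stream.scaladsl.Sink
    import cats.data.NonEmptyList
    import org.apache.kafka.common.serialization.{Deserializer, StringDeserializer}
    import uk.sky.kafka.topicloader.TopicLoader

    implicit val system: ActorSystem             = ActorSystem("partitioned-load-and-run-sketch")
    implicit val stringDes: Deserializer[String] = new StringDeserializer

    // For each partition, run its inner source; the materialised Future[Done]
    // completes when the initial load has caught up, while live records keep
    // flowing through the stream.
    TopicLoader
      .partitionedLoadAndRun[String, String](NonEmptyList.one("my-topic"))
      .runForeach { case (partition, perPartition) =>
        val (loaded, _) = perPartition.to(Sink.ignore).run()
        loaded.foreach(_ => println(s"$partition finished initial load"))(system.dispatcher)
      }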

/** Source that loads the specified topics from the beginning. When the latest current offsets are reached, the
* materialised value is completed, and the stream continues.
*/
def loadAndRun[K : Deserializer, V : Deserializer](
topics: NonEmptyList[String],
maybeConsumerSettings: Option[ConsumerSettings[Array[Byte], Array[Byte]]] = None
)(implicit system: ActorSystem): Source[ConsumerRecord[K, V], (Future[Done], Future[Consumer.Control])] = {
val config = Config.loadOrThrow(system.settings.config).topicLoader
val logOffsetsF = logOffsetsForTopics(topics, LoadAll, config)
val postLoadingSource = Source.futureSource(logOffsetsF.map { logOffsets =>
val highestOffsets = logOffsets.map { case (p, o) => p -> o.highest }
kafkaSource[K, V](highestOffsets, config, maybeConsumerSettings)
}(system.dispatcher))

load[K, V](logOffsetsF, config, maybeConsumerSettings)
.watchTermination()(Keep.right)
.concatMat(postLoadingSource)(Keep.both)
val config = Config.loadOrThrow(system.settings.config).topicLoader
loadAndRun(logOffsetsForTopics(topics, LoadAll, config), config, maybeConsumerSettings)
}
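Illustrative use of the materialised pair (this mirrors the probe-based test further down; the topic name is a placeholder):

    import akka.actor.ActorSystem
    import akka.stream.scaladsl.{Keep, Sink}
    import cats.data.NonEmptyList
    import org.apache.kafka.common.serialization.{Deserializer, StringDeserializer}
    import uk.sky.kafka.topicloader.TopicLoader

    implicit val system: ActorSystem             = ActorSystem("load-and-run-sketch")
    implicit val stringDes: Deserializer[String] = new StringDeserializer

    // loadComplete finishes once the offsets that were highest at start-up have
    // been reached; control can later stop the consumer; streamDone tracks the sink.
    val ((loadComplete, control), streamDone) =
      TopicLoader
        .loadAndRun[String, String](NonEmptyList.one("my-topic"))
        .toMat(Sink.ignore)(Keep.both)
        .run()

    loadComplete.foreach(_ => println("initial load complete"))(system.dispatcher)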

protected def logOffsetsForPartitions(
@@ -157,6 +202,21 @@
}
}

protected def loadAndRun[K : Deserializer, V : Deserializer](
logOffsets: Future[Map[TopicPartition, LogOffsets]],
config: TopicLoaderConfig,
maybeConsumerSettings: Option[ConsumerSettings[Array[Byte], Array[Byte]]]
)(implicit system: ActorSystem): Source[ConsumerRecord[K, V], (Future[Done], Future[Consumer.Control])] = {
val postLoadingSource = Source.futureSource(logOffsets.map { logOffsets =>
val highestOffsets = logOffsets.map { case (p, o) => p -> o.highest }
kafkaSource[K, V](highestOffsets, config, maybeConsumerSettings)
}(system.dispatcher))

load[K, V](logOffsets, config, maybeConsumerSettings)
.watchTermination()(Keep.right)
.concatMat(postLoadingSource)(Keep.both)
}
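The shape of the materialised pair comes from this combination of watchTermination and concatMat; reduced to plain Akka Streams with illustrative stand-in sources:

    import akka.{Done, NotUsed}
    import akka.stream.scaladsl.{Keep, Source}

    import scala.concurrent.Future

    val loading: Source[Int, NotUsed] = Source(1 to 3) // stand-in for load(...)
    val live: Source[Int, NotUsed]    = Source(4 to 6) // stand-in for the post-loading kafkaSource

    // watchTermination materialises a Future[Done] that completes when loading
    // finishes; concatMat(Keep.both) pairs it with live's materialised value and
    // then emits live's elements after loading is exhausted.
    val combined: Source[Int, (Future[Done], NotUsed)] =
      loading.watchTermination()(Keep.right).concatMat(live)(Keep.both)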

protected def load[K : Deserializer, V : Deserializer](
logOffsets: Future[Map[TopicPartition, LogOffsets]],
config: TopicLoaderConfig,
@@ -202,7 +262,7 @@
startingOffsets: Map[TopicPartition, Long],
config: TopicLoaderConfig,
maybeConsumerSettings: Option[ConsumerSettings[Array[Byte], Array[Byte]]]
)(implicit system: ActorSystem) =
)(implicit system: ActorSystem): Source[ConsumerRecord[K, V], Consumer.Control] =
Consumer
.plainSource(consumerSettings(maybeConsumerSettings, config), Subscriptions.assignmentWithOffset(startingOffsets))
.buffer(config.bufferSize.value, OverflowStrategy.backpressure)
@@ -224,11 +284,8 @@

private def withStandaloneConsumer[T](
settings: ConsumerSettings[Array[Byte], Array[Byte]]
)(f: Consumer[Array[Byte], Array[Byte]] => T): T = {
val consumer = settings.createKafkaConsumer()
try f(consumer)
finally consumer.close()
}
)(f: Consumer[Array[Byte], Array[Byte]] => T): T =
Using.resource(settings.createKafkaConsumer())(f)

private def offsetsFrom(partitions: List[TopicPartition])(
f: JList[TopicPartition] => JMap[TopicPartition, JLong]
16 changes: 13 additions & 3 deletions src/test/scala/base/IntegrationSpecBase.scala
@@ -7,21 +7,21 @@ import akka.actor.ActorSystem
import akka.kafka.ConsumerSettings
import akka.util.Timeout
import cats.data.NonEmptyList
import cats.syntax.option._
import cats.syntax.option.*
import com.typesafe.config.ConfigFactory
import io.github.embeddedkafka.Codecs.{stringDeserializer, stringSerializer}
import io.github.embeddedkafka.{EmbeddedKafka, EmbeddedKafkaConfig}
import org.apache.kafka.clients.CommonClientConfigs
import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, ConsumerRecord, ConsumerRecords}
import org.apache.kafka.clients.producer.ProducerConfig
import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
import org.apache.kafka.common.TopicPartition
import org.scalatest.Assertion
import org.scalatest.concurrent.Eventually
import utils.RandomPort

import scala.annotation.tailrec
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import scala.jdk.CollectionConverters.*

abstract class IntegrationSpecBase extends WordSpecBase with Eventually {

@@ -79,6 +79,11 @@ abstract class IntegrationSpecBase extends WordSpecBase with Eventually {

def recordToTuple[K, V](record: ConsumerRecord[K, V]): (K, V) = (record.key(), record.value())

def sourceFromPartition[T](
sources: Seq[(TopicPartition, T)],
partition: Int
): T = sources.find { case (part, _) => part.partition() == partition }.map { case (_, source) => source }.value

val testTopic1 = "load-state-topic-1"
val testTopic2 = "load-state-topic-2"
val testTopicPartitions = 5
@@ -96,6 +101,11 @@ abstract class IntegrationSpecBase extends WordSpecBase with Eventually {
publishToKafka(topic, messages)
publishToKafka(topic, filler)
}

def publishToKafka(topic: String, partition: Int, messages: Seq[(String, String)]): Unit =
messages.foreach { case (k, v) =>
publishToKafka(new ProducerRecord[String, String](topic, partition, k, v))
}
}

trait KafkaConsumer { this: TestContext =>
73 changes: 71 additions & 2 deletions src/test/scala/integration/TopicLoaderIntSpec.scala
@@ -2,16 +2,18 @@ package integration

import java.util.concurrent.TimeoutException as JavaTimeoutException

import akka.Done
import akka.actor.ActorSystem
import akka.kafka.ConsumerSettings
import akka.stream.scaladsl.{Keep, Sink}
import akka.kafka.scaladsl.Consumer
import akka.stream.scaladsl.{Keep, Sink, Source}
import akka.stream.testkit.scaladsl.TestSink
import base.IntegrationSpecBase
import cats.data.NonEmptyList
import cats.syntax.option.*
import com.typesafe.config.{ConfigException, ConfigFactory}
import io.github.embeddedkafka.Codecs.{stringDeserializer, stringSerializer}
import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord}
import org.apache.kafka.common.errors.TimeoutException as KafkaTimeoutException
import org.apache.kafka.common.serialization.ByteArrayDeserializer
import org.scalatest.prop.TableDrivenPropertyChecks.*
@@ -59,6 +61,32 @@ class TopicLoaderIntSpec extends IntegrationSpecBase {
loadedRecords.map(recordToTuple) should contain theSameElementsAs published
}
}

"stream all records from all topics and emit a source per partition" in new TestContext {
val topics = NonEmptyList.one(testTopic1)
val (forPartition1, forPartition2) = records(1 to 15).splitAt(10)
val partitions: Long = 2

withRunningKafka {
createCustomTopics(topics, partitions.toInt)

publishToKafka(testTopic1, 0, forPartition1)
publishToKafka(testTopic1, 1, forPartition2)

val partitionedSources =
TopicLoader.partitionedLoad[String, String](topics, strategy).take(partitions).runWith(Sink.seq).futureValue

sourceFromPartition(partitionedSources, 0)
.runWith(Sink.seq)
.futureValue
.map(recordToTuple) should contain theSameElementsAs forPartition1

sourceFromPartition(partitionedSources, 1)
.runWith(Sink.seq)
.futureValue
.map(recordToTuple) should contain theSameElementsAs forPartition2
}
}
}

"using LoadCommitted strategy" should {
@@ -247,6 +275,47 @@
}
}
}

"execute callback when finished loading and keep streaming per partition" in new TestContext {
val (preLoadPart1, postLoadPart1) = records(1 to 15).splitAt(10)
val (preLoadPart2, postLoadPart2) = records(16 to 30).splitAt(10)
val partitions: Long = 2

withRunningKafka {
createCustomTopic(testTopic1, partitions = partitions.toInt)

publishToKafka(testTopic1, 0, preLoadPart1)
publishToKafka(testTopic1, 1, preLoadPart2)

val partitionedStream = TopicLoader
Contributor:
Does this test fail if this were to use just loadAndRun?

Member (author):
Do you mean if we publish to one partition and then the other? It still passes with loadAndRun:

    "execute callback when finished loading and keep streaming per partition" in new TestContext {
      val (preLoadPart1, postLoadPart1) = records(1 to 15).splitAt(10)
      val (preLoadPart2, postLoadPart2) = records(16 to 30).splitAt(10)
      val partitions: Long              = 2

      withRunningKafka {
        createCustomTopic(testTopic1, partitions = partitions.toInt)

        publishToKafka(testTopic1, 0, preLoadPart1)
        publishToKafka(testTopic1, 1, preLoadPart2)

        val ((callback, _), recordsProbe) =
          TopicLoader.loadAndRun[String, String](NonEmptyList.one(testTopic1)).toMat(TestSink.probe)(Keep.both).run()

        recordsProbe.request(
          preLoadPart1.size.toLong + postLoadPart1.size.toLong + preLoadPart2.size.toLong + postLoadPart2.size.toLong
        )
        recordsProbe
          .expectNextN(preLoadPart1.size.toLong + preLoadPart2.size.toLong)
          .map(recordToTuple) should contain theSameElementsAs preLoadPart1 ++ preLoadPart2

        whenReady(callback) { _ =>
          publishToKafka(testTopic1, 0, postLoadPart1)
          publishToKafka(testTopic1, 1, postLoadPart2)

          recordsProbe
            .expectNextN(postLoadPart1.size.toLong + postLoadPart2.size.toLong)
            .map(recordToTuple) should contain theSameElementsAs postLoadPart1 ++ postLoadPart2
        }
      }
    }

Contributor:
Yeah, so my point is I don't think we have a test that proves this partitioned-load functionality, considering the test passes even with loadAndRun.

.partitionedLoadAndRun[String, String](NonEmptyList.one(testTopic1))
.take(partitions)
.runWith(Sink.seq)
.futureValue

def validate(
source: Source[ConsumerRecord[String, String], (Future[Done], Future[Consumer.Control])],
partition: Int,
preLoad: Seq[(String, String)],
postLoad: Seq[(String, String)]
): Unit = {

val ((callback, _), recordsProbe) = source.toMat(TestSink.probe)(Keep.both).run()

recordsProbe.request(preLoad.size.toLong + postLoad.size.toLong)
recordsProbe.expectNextN(preLoad.size.toLong).map(recordToTuple) shouldBe preLoad

whenReady(callback) { _ =>
publishToKafka(testTopic1, partition, postLoad)

recordsProbe.expectNextN(postLoad.size.toLong).map(recordToTuple) shouldBe postLoad
}
}

validate(sourceFromPartition(partitionedStream, 0), 0, preLoadPart1, postLoadPart1)
validate(sourceFromPartition(partitionedStream, 1), 1, preLoadPart2, postLoadPart2)
}
}
}

"consumerSettings" should {
Expand Down