diff --git a/src/main/scala/uk/sky/kafka/topicloader/TopicLoader.scala b/src/main/scala/uk/sky/kafka/topicloader/TopicLoader.scala
index b3c49327..6b86789a 100644
--- a/src/main/scala/uk/sky/kafka/topicloader/TopicLoader.scala
+++ b/src/main/scala/uk/sky/kafka/topicloader/TopicLoader.scala
@@ -168,6 +168,8 @@ trait TopicLoader extends LazyLogging {
   def topicDataSource(offsets: Map[TopicPartition, LogOffsets]): Source[ConsumerRecord[K, V], Consumer.Control] = {
     offsets.foreach { case (partition, offset) => logger.info(s"${offset.show} for $partition") }
 
+    val partitions = offsets.keys
+
     val nonEmptyOffsets   = offsets.filter { case (_, o) => o.highest > o.lowest }
     val lowestOffsets     = nonEmptyOffsets.map { case (p, o) => p -> o.lowest }
     val allHighestOffsets =
@@ -184,19 +186,19 @@ trait TopicLoader extends LazyLogging {
       .via(filterBelowHighestOffset)
       .wireTap(topicLoaderMetrics.onRecord[K, V])
       .mapMaterializedValue { mat =>
-        topicLoaderMetrics.onLoading()
+        partitions.foreach(topicLoaderMetrics.onLoading)
         mat
       }
       .watchTermination() { case (mat, terminationF) =>
         terminationF.onComplete(
           _.fold(
             e => {
-              logger.error(s"Error occurred while loading data from ${offsets.keys.show}", e)
-              topicLoaderMetrics.onError()
+              logger.error(s"Error occurred while loading data from ${partitions.show}", e)
+              partitions.foreach(topicLoaderMetrics.onError)
             },
             _ => {
-              logger.info(s"Successfully loaded data from ${offsets.keys.show}")
-              topicLoaderMetrics.onLoaded()
+              logger.info(s"Successfully loaded data from ${partitions.show}")
+              partitions.foreach(topicLoaderMetrics.onLoaded)
             }
           )
         )(system.dispatcher)
diff --git a/src/main/scala/uk/sky/kafka/topicloader/metrics/TopicLoaderMetrics.scala b/src/main/scala/uk/sky/kafka/topicloader/metrics/TopicLoaderMetrics.scala
index b7084a2b..952626cb 100644
--- a/src/main/scala/uk/sky/kafka/topicloader/metrics/TopicLoaderMetrics.scala
+++ b/src/main/scala/uk/sky/kafka/topicloader/metrics/TopicLoaderMetrics.scala
@@ -1,16 +1,17 @@
 package uk.sky.kafka.topicloader.metrics
 
 import org.apache.kafka.clients.consumer.ConsumerRecord
+import org.apache.kafka.common.TopicPartition
 
 trait TopicLoaderMetrics {
 
   def onRecord[K, V](record: ConsumerRecord[K, V]): Unit
 
-  def onLoading(): Unit
+  def onLoading(topicPartition: TopicPartition): Unit
 
-  def onLoaded(): Unit
+  def onLoaded(topicPartition: TopicPartition): Unit
 
-  def onError(): Unit
+  def onError(topicPartition: TopicPartition): Unit
 }
 
 object TopicLoaderMetrics {
@@ -18,10 +19,10 @@ object TopicLoaderMetrics {
   def noOp(): TopicLoaderMetrics = new TopicLoaderMetrics {
     override def onRecord[K, V](record: ConsumerRecord[K, V]): Unit = ()
 
-    override def onLoading(): Unit = ()
+    override def onLoading(topicPartition: TopicPartition): Unit = ()
 
-    override def onLoaded(): Unit = ()
+    override def onLoaded(topicPartition: TopicPartition): Unit = ()
 
-    override def onError(): Unit = ()
+    override def onError(topicPartition: TopicPartition): Unit = ()
   }
 }
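With the callbacks now receiving the `TopicPartition` they fire for, implementations can track loading state per partition instead of per stream. Below is a minimal sketch of a custom implementation against the new trait; the `PartitionStateMetrics` name and its internal layout are illustrative, not part of the library:

```scala
import java.util.concurrent.atomic.AtomicLong

import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.common.TopicPartition
import uk.sky.kafka.topicloader.metrics.TopicLoaderMetrics

import scala.collection.concurrent.TrieMap

// Illustrative implementation of the per-partition API: keeps a record count
// and the last reported state for every partition being loaded.
class PartitionStateMetrics extends TopicLoaderMetrics {

  private val recordCounts = TrieMap.empty[TopicPartition, AtomicLong]
  private val lastState    = TrieMap.empty[TopicPartition, String]

  override def onRecord[K, V](record: ConsumerRecord[K, V]): Unit = {
    // Derive the partition key from the record itself.
    val tp = new TopicPartition(record.topic(), record.partition())
    recordCounts.getOrElseUpdate(tp, new AtomicLong()).incrementAndGet()
    ()
  }

  override def onLoading(topicPartition: TopicPartition): Unit = {
    lastState.put(topicPartition, "loading"); ()
  }

  override def onLoaded(topicPartition: TopicPartition): Unit = {
    lastState.put(topicPartition, "loaded"); ()
  }

  override def onError(topicPartition: TopicPartition): Unit = {
    lastState.put(topicPartition, "error"); ()
  }
}
```

Because `TopicLoader` now invokes `onLoading`/`onLoaded`/`onError` once per partition (via `partitions.foreach`), an implementation like this records one state transition per partition rather than a single global one.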
diff --git a/src/test/scala/integration/TopicLoaderIntSpec.scala b/src/test/scala/integration/TopicLoaderIntSpec.scala
index 531ba3b9..e5a6e39f 100644
--- a/src/test/scala/integration/TopicLoaderIntSpec.scala
+++ b/src/test/scala/integration/TopicLoaderIntSpec.scala
@@ -9,6 +9,7 @@ import base.IntegrationSpecBase
 import cats.data.NonEmptyList
 import com.typesafe.config.ConfigFactory
 import io.github.embeddedkafka.Codecs.{stringDeserializer, stringSerializer}
+import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.errors.TimeoutException as KafkaTimeoutException
 import org.scalatest.prop.TableDrivenPropertyChecks.*
 import org.scalatest.prop.Tables.Table
@@ -177,12 +178,18 @@ class TopicLoaderIntSpec extends IntegrationSpecBase {
     "emit a State of Loaded once the stream has completed" in new TestContext {
       val mockTopicLoaderMetrics = new MockTopicLoaderMetrics()
 
+      val partitions = 2
       val topics     = NonEmptyList.of(testTopic1, testTopic2)
       val allRecords = records(1 to 15)
       val (forTopic1, forTopic2) = allRecords.splitAt(10)
 
+      val tps: Seq[TopicPartition] = for {
+        topic     <- topics.toList
+        partition <- 0 until partitions
+      } yield new TopicPartition(topic, partition)
+
       withRunningKafka {
-        createCustomTopics(topics)
+        createCustomTopics(topics, partitions)
 
         publishToKafka(testTopic1, forTopic1)
         publishToKafka(testTopic2, forTopic2)
@@ -191,14 +198,14 @@ class TopicLoaderIntSpec extends IntegrationSpecBase {
         TopicLoader
           .load[String, String](topics, strategy, topicLoaderMetrics = mockTopicLoaderMetrics)
 
-        mockTopicLoaderMetrics.loadingState.get() shouldBe NotStarted
+        mockTopicLoaderMetrics.loadingState shouldBe empty
 
         loadF.runWith(Sink.foreach { _ =>
-          mockTopicLoaderMetrics.loadingState.get() shouldBe Loading
+          tps.foreach(tp => mockTopicLoaderMetrics.loadingState.get(tp).value shouldBe Loading(tp))
         })
 
         eventually {
-          mockTopicLoaderMetrics.loadingState.get() shouldBe Loaded
+          tps.foreach(tp => mockTopicLoaderMetrics.loadingState.get(tp).value shouldBe Loaded(tp))
         }
       }
     }
@@ -209,8 +216,14 @@ class TopicLoaderIntSpec extends IntegrationSpecBase {
       val published    = records(1 to 10)
      val explodingKey = published.drop(5).head._1
 
+      val partitions = 2
+
+      val tps: Seq[TopicPartition] = for {
+        partition <- 0 until partitions
+      } yield new TopicPartition(testTopic1, partition)
+
       withRunningKafka {
-        createCustomTopic(testTopic1)
+        createCustomTopic(testTopic1, partitions = partitions)
         publishToKafka(testTopic1, published)
 
         TopicLoader
@@ -218,7 +231,7 @@ class TopicLoaderIntSpec extends IntegrationSpecBase {
           .runWith(errorSink(explodingKey))
 
         eventually {
-          mockTopicLoaderMetrics.loadingState.get() shouldBe ErrorLoading
+          tps.foreach(tp => mockTopicLoaderMetrics.loadingState.get(tp).value shouldBe ErrorLoading(tp))
         }
       }
     }
@@ -320,10 +333,15 @@ class TopicLoaderIntSpec extends IntegrationSpecBase {
     "emit a State of Loaded when finished loading" in new TestContext {
       val mockTopicLoaderMetrics = new MockTopicLoaderMetrics()
 
-      val preLoad = records(1 to 15)
+      val partitions = 2
+      val preLoad    = records(1 to 15)
+
+      val tps: Seq[TopicPartition] = for {
+        partition <- 0 until partitions
+      } yield new TopicPartition(testTopic1, partition)
 
       withRunningKafka {
-        createCustomTopic(testTopic1)
+        createCustomTopic(testTopic1, partitions = partitions)
 
         publishToKafka(testTopic1, preLoad)
 
@@ -334,7 +352,7 @@ class TopicLoaderIntSpec extends IntegrationSpecBase {
           .run()
 
         whenReady(callback) { _ =>
-          mockTopicLoaderMetrics.loadingState.get() shouldBe Loaded
+          tps.foreach(tp => mockTopicLoaderMetrics.loadingState.get(tp).value shouldBe Loaded(tp))
         }
       }
     }
@@ -342,11 +360,16 @@ class TopicLoaderIntSpec extends IntegrationSpecBase {
     "emit a State of Error if the initial loading callback fails" in new TestContext {
       val mockTopicLoaderMetrics = new MockTopicLoaderMetrics()
 
+      val partitions   = 2
       val preLoad      = records(1 to 15)
       val explodingKey = preLoad.drop(5).head._1
 
+      val tps: Seq[TopicPartition] = for {
+        partition <- 0 until partitions
+      } yield new TopicPartition(testTopic1, partition)
+
       withRunningKafka {
-        createCustomTopic(testTopic1)
+        createCustomTopic(testTopic1, partitions = partitions)
 
         publishToKafka(testTopic1, preLoad)
 
@@ -359,7 +382,7 @@ class TopicLoaderIntSpec extends IntegrationSpecBase {
 
       eventually {
         callback.failed
-        mockTopicLoaderMetrics.loadingState.get() shouldBe ErrorLoading
+        tps.foreach(tp => mockTopicLoaderMetrics.loadingState.get(tp).value shouldBe ErrorLoading(tp))
       }
     }
   }
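For reference, the `tps` helpers above enumerate every partition of every topic under load, and each resulting key must independently reach the expected state for the assertions to pass. A standalone sketch (with illustrative topic names) of what the helper produces for two topics with two partitions each:

```scala
import org.apache.kafka.common.TopicPartition

@main def showTopicPartitions(): Unit = {
  val topics     = List("topic-1", "topic-2") // stand-ins for testTopic1/testTopic2
  val partitions = 2

  // Same shape as the `tps` helper in the spec: the cross product of
  // topic names and partition indices.
  val tps: Seq[TopicPartition] = for {
    topic     <- topics
    partition <- 0 until partitions
  } yield new TopicPartition(topic, partition)

  // Prints: topic-1-0, topic-1-1, topic-2-0, topic-2-1
  println(tps.mkString(", "))
}
```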
diff --git a/src/test/scala/utils/MockTopicLoaderMetrics.scala b/src/test/scala/utils/MockTopicLoaderMetrics.scala
index 02a19a15..af1638ae 100644
--- a/src/test/scala/utils/MockTopicLoaderMetrics.scala
+++ b/src/test/scala/utils/MockTopicLoaderMetrics.scala
@@ -3,28 +3,34 @@ package utils
-import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.kafka.clients.consumer.ConsumerRecord
+import org.apache.kafka.common.TopicPartition
 import uk.sky.kafka.topicloader.metrics.TopicLoaderMetrics
-import utils.MockTopicLoaderMetrics._
+import utils.MockTopicLoaderMetrics.*
+
+import scala.collection.concurrent.TrieMap
 
 class MockTopicLoaderMetrics extends TopicLoaderMetrics {
 
   val recordCounter = new AtomicInteger()
-  val loadingState  = new AtomicReference[State](NotStarted)
+  val loadingState  = TrieMap.empty[TopicPartition, State]
 
   override def onRecord[K, V](record: ConsumerRecord[K, V]): Unit = recordCounter.incrementAndGet()
 
-  override def onLoading(): Unit = loadingState.set(Loading)
+  override def onLoading(topicPartition: TopicPartition): Unit =
+    loadingState.put(topicPartition, Loading(topicPartition))
 
-  override def onLoaded(): Unit = loadingState.set(Loaded)
+  override def onLoaded(topicPartition: TopicPartition): Unit =
+    loadingState.put(topicPartition, Loaded(topicPartition))
 
-  override def onError(): Unit = loadingState.set(ErrorLoading)
+  override def onError(topicPartition: TopicPartition): Unit =
+    loadingState.put(topicPartition, ErrorLoading(topicPartition))
 }
 
 object MockTopicLoaderMetrics {
   sealed trait State extends Product with Serializable
 
-  case object NotStarted   extends State
-  case object Loading      extends State
-  case object Loaded       extends State
-  case object ErrorLoading extends State
+  case object NotStarted                                   extends State
+  case class Loading(topicPartition: TopicPartition)       extends State
+  case class Loaded(topicPartition: TopicPartition)        extends State
+  case class ErrorLoading(topicPartition: TopicPartition)  extends State
 }
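A short sketch of how the reworked mock behaves in isolation, with illustrative partition values: a partition has no entry until `onLoading` fires for it, and later callbacks overwrite only that partition's entry, which is why the specs assert on each `TopicPartition` key individually:

```scala
import org.apache.kafka.common.TopicPartition
import utils.MockTopicLoaderMetrics

@main def mockMetricsWalkthrough(): Unit = {
  val metrics = new MockTopicLoaderMetrics()
  val tp0     = new TopicPartition("topic-1", 0)
  val tp1     = new TopicPartition("topic-1", 1)

  println(metrics.loadingState.get(tp0)) // None: nothing reported yet

  metrics.onLoading(tp0)
  metrics.onLoading(tp1)
  metrics.onLoaded(tp0)

  println(metrics.loadingState.get(tp0)) // Some(Loaded(topic-1-0))
  println(metrics.loadingState.get(tp1)) // Some(Loading(topic-1-1)): tp1 still loading
}
```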