Skip to content

Commit

Permalink
Merge pull request #122 from alm0ra/add-type-hints-and-refactor
Browse files Browse the repository at this point in the history
Add type hints & raises error when consumer is closed
  • Loading branch information
alm0ra authored Jul 10, 2024
2 parents c0b2cfe + 8a4cde9 commit 7b34aa8
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 27 deletions.
3 changes: 3 additions & 0 deletions docs/async-fake-aiokafka-consumer.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ fake_consumer = FakeAIOKafkaConsumer()
# Subscribe to topics
fake_consumer.subscribe(topics=['sample_topic1', 'sample_topic2'])

# start consumer
await fake_consumer.start()

# Get one message
message = await fake_consumer.getone()
```
61 changes: 37 additions & 24 deletions mockafka/aiokafka/aiokafka_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import re
import warnings
from collections.abc import Iterable, Iterator
from typing import Any
from typing import Any, Optional

from aiokafka.abc import ConsumerRebalanceListener # type: ignore[import-untyped]
from aiokafka.errors import ConsumerStoppedError # type: ignore[import-untyped]
from aiokafka.structs import ( # type: ignore[import-untyped]
ConsumerRecord,
TopicPartition,
Expand All @@ -19,9 +20,9 @@


def message_to_record(message: Message, offset: int) -> ConsumerRecord[bytes, bytes]:
topic: str | None = message.topic()
partition: int | None = message.partition()
timestamp: int | None = message.timestamp()
topic: Optional[str] = message.topic()
partition: Optional[int] = message.partition()
timestamp: Optional[int] = message.timestamp()

if topic is None or partition is None or timestamp is None:
fields = [
Expand All @@ -32,8 +33,8 @@ def message_to_record(message: Message, offset: int) -> ConsumerRecord[bytes, by
missing = ", ".join(x for x, y in fields if y is None)
raise ValueError(f"Message is missing key components: {missing}")

key_str: str | None = message.key()
value_str: str | None = message.value()
key_str: Optional[str] = message.key()
value_str: Optional[str] = message.value()

key = key_str.encode() if key_str is not None else None
value = value_str.encode() if value_str is not None else None
Expand Down Expand Up @@ -87,35 +88,39 @@ def __init__(self, *topics: str, **kwargs: Any) -> None:
self.kafka = KafkaStore()
self.consumer_store: dict[str, int] = {}
self.subscribed_topic = [x for x in topics if self.kafka.is_topic_exist(x)]
self._is_closed = True

async def start(self) -> None:
self.consumer_store = {}
self._is_closed = False

async def stop(self) -> None:
self.consumer_store = {}
self._is_closed = True

async def commit(self):
for item in self.consumer_store:
topic, partition = item.split("*")
if (
self.kafka.get_partition_first_offset(topic, partition)
<= self.consumer_store[item]
self.kafka.get_partition_first_offset(topic, partition)
<= self.consumer_store[item]
):
self.kafka.set_first_offset(
topic=topic, partition=partition, value=self.consumer_store[item]
)

self.consumer_store = {}

async def topics(self):
async def topics(self) -> list[str]:
return self.subscribed_topic

def subscribe(
self,
topics: list[str] | set[str] | tuple[str, ...] = (),
pattern: str | None = None,
listener: ConsumerRebalanceListener | None = None,
self,
topics: list[str] | set[str] | tuple[str, ...] = (),
pattern: str | None = None,
listener: Optional[ConsumerRebalanceListener] = None,
) -> None:

if topics and pattern:
raise ValueError(
"Only one of `topics` and `pattern` may be provided (not both).",
Expand Down Expand Up @@ -146,16 +151,17 @@ def subscribe(
if topic not in self.subscribed_topic:
self.subscribed_topic.append(topic)

def subscribtion(self) -> list[str]:
def subscription(self) -> list[str]:
return self.subscribed_topic

def unsubscribe(self):
def unsubscribe(self) -> None:
self.subscribed_topic = []

def _get_key(self, topic, partition) -> str:
return f"{topic}*{partition}"

def _fetch_one(self, topic: str, partition: int) -> Message | None:
def _fetch_one(self, topic: str, partition: int) -> Optional[ConsumerRecord[bytes, bytes]]:

first_offset = self.kafka.get_partition_first_offset(
topic=topic, partition=partition
)
Expand All @@ -181,9 +187,10 @@ def _fetch_one(self, topic: str, partition: int) -> Message | None:
return message_to_record(message, offset=consumer_amount)

def _fetch(
self,
partitions: Iterable[TopicPartition],
self,
partitions: Iterable[TopicPartition],
) -> Iterator[tuple[TopicPartition, ConsumerRecord[bytes, bytes]]]:

if partitions:
partitions_to_consume = list(partitions)
else:
Expand All @@ -205,19 +212,25 @@ def _fetch(
yield tp, record

async def getone(
self, *partitions: TopicPartition
) -> ConsumerRecord[bytes, bytes] | None:
self, *partitions: TopicPartition
) -> Optional[ConsumerRecord[bytes, bytes]]:
if self._is_closed:
raise ConsumerStoppedError()

for _, record in self._fetch(partitions):
return record

return None

async def getmany(
self,
*partitions: TopicPartition,
timeout_ms: int = 0,
max_records: int | None = None,
self,
*partitions: TopicPartition,
timeout_ms: int = 0,
max_records: Optional[int] = None,
) -> dict[TopicPartition, list[ConsumerRecord[bytes, bytes]]]:
if self._is_closed:
raise ConsumerStoppedError()

records = self._fetch(partitions)
if max_records is not None:
records = itertools.islice(records, max_records)
Expand Down
1 change: 1 addition & 0 deletions mockafka/decorators/aconsumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def decorator(func):
async def wrapper(*args, **kwargs):
# Create a FakeConsumer instance and subscribe to specified topics
fake_consumer = FakeAIOKafkaConsumer()
await fake_consumer.start()
fake_consumer.subscribe(topics=topics)

# Simulate message consumption
Expand Down
25 changes: 22 additions & 3 deletions tests/test_aiokafka/test_aiokafka_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest import IsolatedAsyncioTestCase

import pytest
from aiokafka.errors import ConsumerStoppedError # type: ignore[import-untyped]
from aiokafka.structs import ( # type: ignore[import-untyped]
ConsumerRecord,
TopicPartition,
Expand All @@ -20,8 +21,8 @@
@pytest.mark.asyncio
class TestAIOKAFKAFakeConsumer(IsolatedAsyncioTestCase):
def summarise(
self,
records: list[ConsumerRecord],
self,
records: list[ConsumerRecord],
) -> list[tuple[str | None, str | None]]:
return [(x.key, x.value) for x in records]

Expand Down Expand Up @@ -62,12 +63,15 @@ async def test_start(self):
# close consumer and check consumer store and consume return none
await self.consumer.stop()
self.assertEqual(self.consumer.consumer_store, {})

await self.consumer.start()
self.assertIsNone(await self.consumer.getone())

async def test_poll_without_commit(self):
self.create_topic()
await self.produce_message()
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

message = await self.consumer.getone()
self.assertEqual(message.value, b"test")
Expand All @@ -81,6 +85,7 @@ async def test_partition_specific_poll_without_commit(self):
self.create_topic()
await self.produce_message()
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

message = await self.consumer.getone(
TopicPartition(self.test_topic, 2),
Expand All @@ -96,6 +101,7 @@ async def test_poll_with_commit(self):
self.create_topic()
await self.produce_message()
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

message = await self.consumer.getone()
await self.consumer.commit()
Expand All @@ -115,6 +121,7 @@ async def test_getmany_without_commit(self):
topic=self.test_topic, partition=2, key="test2", value="test2"
)
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

# Order unknown as partition order is not predictable
messages = {
Expand Down Expand Up @@ -143,7 +150,7 @@ async def test_getmany_with_limit_without_commit(self):
topic=self.test_topic, partition=0, key="test2", value="test2"
)
self.consumer.subscribe(topics=[self.test_topic])

await self.consumer.start()
messages = {
tp: self.summarise(msgs)
for tp, msgs in (await self.consumer.getmany(max_records=2)).items()
Expand Down Expand Up @@ -180,6 +187,7 @@ async def test_getmany_specific_poll_without_commit(self):
topic=self.test_topic, partition=1, key="test2", value="test2"
)
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

target = TopicPartition(self.test_topic, 0)

Expand Down Expand Up @@ -207,6 +215,7 @@ async def test_getmany_with_commit(self):
topic=self.test_topic, partition=2, key="test2", value="test2"
)
self.consumer.subscribe(topics=[self.test_topic])
await self.consumer.start()

# Order unknown, though we can check the counts eagerly
messages = {
Expand Down Expand Up @@ -319,8 +328,18 @@ async def test_unsubscribe(self):
self.kafka.create_partition(topic=self.test_topic, partitions=10)

topics = [self.test_topic]

self.consumer.subscribe(topics=topics)

self.assertEqual(self.consumer.subscribed_topic, topics)
self.consumer.unsubscribe()
self.assertEqual(self.consumer.subscribed_topic, [])

async def test_consumer_is_stopped(self):
self.kafka.create_partition(topic=self.test_topic, partitions=10)

topics = [self.test_topic]

self.consumer.subscribe(topics=topics)
with self.assertRaises(ConsumerStoppedError):
await self.consumer.getone()
4 changes: 4 additions & 0 deletions tests/test_aiokafka/test_async_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ async def _create_fake_topics(self):

@aproduce(topic="test", key="test_key", value="test_value", partition=4)
async def test_produce_decorator(self):
await self.consumer.start()

# subscribe to topic and get message
self.consumer.subscribe(topics=["test"])
message = await self.consumer.getone()
Expand All @@ -62,6 +64,7 @@ async def test_produce_decorator(self):
@aproduce(topic="test", key="test_key", value="test_value", partition=4)
@aproduce(topic="test", key="test_key1", value="test_value1", partition=0)
async def test_produce_twice(self):
await self.consumer.start()
# subscribe to topic and get message
self.consumer.subscribe(topics=["test"])

Expand Down Expand Up @@ -90,6 +93,7 @@ async def test_produce_twice(self):
@asetup_kafka(topics=[{"topic": "test_topic", "partition": 16}])
@aproduce(topic="test_topic", partition=5, key="test_", value="test_value1")
async def test_produce_with_kafka_setup_decorator(self):
await self.consumer.start()
# subscribe to topic and get message
self.consumer.subscribe(topics=["test_topic"])

Expand Down

0 comments on commit 7b34aa8

Please sign in to comment.