Skip to content

Commit

Permalink
Added getter and setter for consumer and producer.
Browse files Browse the repository at this point in the history
  • Loading branch information
Menziess committed May 21, 2024
1 parent 2142787 commit 6b58155
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 29 deletions.
71 changes: 49 additions & 22 deletions snapstream/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,27 +264,50 @@ def create_topic(self, *args, **kwargs) -> None:
logger.error(e)
raise

@property
def consumer(self) -> Optional[Consumer]:
    """Lazily create, subscribe, and return the underlying ``Consumer``.

    On first access a ``Consumer`` is built from ``self.conf`` and
    subscribed to the topic ``self.name``; later accesses return the
    cached instance unchanged.  ``del topic.consumer`` drops the cache
    so the next access builds a fresh consumer.
    """
    if not self._consumer:
        self._consumer = Consumer(self.conf, logger=logger)

        def on_assign(c, ps):
            # When an explicit starting offset was requested, apply it
            # to every assigned partition before (re)assigning.
            for p in ps:
                if self.starting_offset is not None:
                    p.offset = self.starting_offset
            c.assign(ps)

        logger.debug(f'Subscribing to topic: {self.name}.')
        self._consumer.subscribe([self.name], on_assign=on_assign)

    return self._consumer

@consumer.deleter
def consumer(self):
    """Discard the cached consumer; the next access creates a new one."""
    # NOTE(review): assumes the caller has already closed the consumer
    # (see usage in `_get_iterable`) — deleting only drops the reference.
    self._consumer = None

@property
def producer(self) -> Producer:
    """Lazily create and return the underlying ``Producer``.

    Built from ``self.conf`` on first access and cached thereafter.
    """
    if not self._producer:
        self._producer = Producer(self.conf, logger=logger)
    return self._producer

@producer.deleter
def producer(self):
    """Discard the cached producer; the next access creates a new one."""
    self._producer = None

@contextmanager
def _get_consumer(self) -> Iterator[Iterable[Any]]:
def _get_iterable(self) -> Iterator[Iterable[Any]]:
"""Yield an iterable to consume from kafka."""
self._consumer = Consumer(self.conf, logger=logger)
commit_each_message = pipe(
self.conf.get('enable.auto.commit'),
str,
str.lower
) == 'false' and self.commit_each_message

def consume():
def on_assign(c, ps):
for p in ps:
if self.starting_offset is not None:
p.offset = self.starting_offset
c.assign(ps)

logger.debug(f'Subscribing to topic: {self.name}.')
cast(Consumer, self._consumer).subscribe([self.name], on_assign=on_assign)
logger.debug(f'Consuming from topic: {self.name}.')
yield from self.poller(self._consumer, self.poll_timeout, self.codec,
yield from self.poller(self.consumer, self.poll_timeout, self.codec,
self.raise_error, commit_each_message)
yield consume()
leave_msg = (
Expand All @@ -293,26 +316,28 @@ def on_assign(c, ps):
else 'Committing offsets and leaving group'
)
logger.debug(f'{leave_msg}, flush_timeout={self.flush_timeout}.')
self._consumer.close()
if self._consumer:
cast(Consumer, self.consumer).close()
del self.consumer

@contextmanager
def _get_callable(self) -> Iterator[Callable[[Any, Any], None]]:
    """Yield kafka produce method.

    Yields the push callable bound to the lazily-created producer
    (see the ``producer`` property).  On exit, outstanding messages
    are flushed — but only when a producer was actually created, so
    entering and leaving without producing stays side-effect free.
    """
    yield self.pusher(self.producer, self.name, self.poll_timeout, self.codec, self.dry)
    logger.debug(f'Flushing messages to kafka, flush_timeout={self.flush_timeout}.')
    # Guard on the private attribute: `self.producer` would eagerly
    # create a producer just to flush nothing.
    if self._producer:
        self.producer.flush(self.flush_timeout)

def __iter__(self) -> Iterator[Any]:
    """Consume messages from the topic.

    Opens the consuming context and yields messages until the
    underlying iterable is exhausted or the caller stops iterating;
    the context manager then closes the consumer.
    """
    with self._get_iterable() as consumer:
        yield from consumer

def __next__(self) -> Any:
"""Consume next message from topic."""
c = self._get_consumer()
c = self._get_iterable()
with c as consumer:
for msg in consumer:
return msg
Expand All @@ -330,21 +355,23 @@ def __getitem__(self, i) -> Any:
i.step,
i.stop
)
c = self._get_consumer()
c = self._get_iterable()
with c as consumer:
for msg in consumer:
if start and start > msg.offset():
continue
if stop and msg.offset() >= stop:
if stop and msg.offset() > stop:
return
if step and (msg.offset() - max(0, start)) % step != 0:
if step and (msg.offset() - max(0, start or 0)) % step != 0:
continue
yield msg
if stop and msg.offset() >= stop:
return

def __call__(self, val, key=None, *args, **kwargs) -> None:
    """Produce a message to the topic.

    The produce callable is created once on first call (by manually
    entering the `_get_callable` context manager) and reused for
    subsequent calls.
    """
    if not self._producer_callable:
        # NOTE(review): the context manager is entered here but its
        # __exit__ (flush) is presumably triggered elsewhere — confirm
        # teardown path.
        self._producer_ctx_mgr = self._get_callable()
        self._producer_callable = self._producer_ctx_mgr.__enter__()
    self._producer_callable(key, val, *args, **kwargs)

Expand Down
39 changes: 36 additions & 3 deletions tests/integration/test_kafka_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import logging

import pytest
from toolz import last

from snapstream import Topic
from snapstream.core import READ_FROM_START


def test_produce_no_kafka(caplog):
Expand Down Expand Up @@ -48,13 +50,44 @@ def test_consume_no_kafka(caplog, timeout):

def test_produce_consume(kafka):
    """Should be able to exchange messages with kafka."""
    # Unique topic/group names keep this test isolated from others
    # running against the same Kafka container.
    t = Topic('test_produce_consume', {
        'bootstrap.servers': kafka,
        'auto.offset.reset': 'earliest',
        'group.instance.id': 'test_produce_consume',
        'group.id': 'test_produce_consume',
    })

    t('test')

    assert next(t[0]).value() == b'test'

    # Close consumer before KafkaContainer goes down
    del t


def test_slice_dice(kafka):
    """Should be able to slice into a topic and resume consuming."""
    t = Topic('test_slice_dice', {
        'bootstrap.servers': kafka,
        'auto.offset.reset': 'earliest',
        'group.instance.id': 'test_slice_dice',
        'group.id': 'test_slice_dice',
    }, offset=READ_FROM_START)

    # Produce test0, test1, test2 at offsets 0, 1, 2.
    for x in range(3):
        t(f'test{x}')

    # Consume last message of the slice (stop offset is inclusive),
    # which should close the consumer afterwards
    assert last(t[:2]).value() == b'test2'

    # Consume first
    assert next(t[0]).value() == b'test0'

    # Continue after first
    for i, msg in enumerate(t):
        assert msg.value() == f'test{i + 1}'.encode()
        if i == 1:
            break

    # Close consumer before KafkaContainer goes down
    del t
8 changes: 4 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@ class MyFailingTopic(ITopic):
MyFailingTopic() # type: ignore


def test_get_iterable():
    """Should return an iterable."""
    t = Topic('test', {'group.id': 'test'}, poll_timeout=0)
    with t._get_iterable() as c:
        assert isinstance(c, Iterable)


def test_get_callable(mocker):
    """Should return a callable."""
    t = Topic('test', {}, flush_timeout=0)
    with t._get_callable() as p:
        assert isinstance(p, Callable)
        p('test', 'test')

Expand Down

0 comments on commit 6b58155

Please sign in to comment.