Skip to content

Commit

Permalink
set active_tps in place (faust-streaming#239)
Browse files Browse the repository at this point in the history
* set active_tps in place

* Cancel getmany during rebalance

* Fix nonetype error

* Check for nonetype

* Fix wait_results coro check

* change == to is

Co-authored-by: Eric Kerstens <ekerstens@expediagroup.com>
  • Loading branch information
ekerstens and Eric Kerstens authored Dec 14, 2021
1 parent a79563a commit ef3c6e1
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions faust/transport/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ class Consumer(Service, ConsumerT):
_commit_every: Optional[int]
_n_acked: int = 0

_active_partitions: Optional[Set[TP]]
_active_partitions: Set[TP]
_paused_partitions: Set[TP]
_buffered_partitions: Set[TP]

Expand Down Expand Up @@ -495,7 +495,7 @@ def on_init_dependencies(self) -> Iterable[ServiceT]:
return []

def _reset_state(self) -> None:
self._active_partitions = None
self._active_partitions = set()
self._paused_partitions = set()
self._buffered_partitions = set()
self.can_resume_flow.clear()
Expand All @@ -516,9 +516,12 @@ def _get_active_partitions(self) -> Set[TP]:
return tps

def _set_active_tps(self, tps: Set[TP]) -> Set[TP]:
xtps = self._active_partitions = ensure_TPset(tps) # copy
xtps.difference_update(self._paused_partitions)
return xtps
if self._active_partitions is None:
self._active_partitions = set()
self._active_partitions.clear()
self._active_partitions.update(ensure_TPset(tps))
self._active_partitions.difference_update(self._paused_partitions)
return self._active_partitions

def on_buffer_full(self, tp: TP) -> None:
# do not remove the partition when in recovery
Expand Down Expand Up @@ -730,6 +733,13 @@ async def getmany(self, timeout: float) -> AsyncIterator[Tuple[TP, Message]]:
# convert timestamp to seconds from int milliseconds.
yield tp, to_message(tp, record)

async def _wait_suspend(self):
"""Wrapper around self.suspend_flow.wait() with no return value.
This allows for easily
"""
await self.suspend_flow.wait()

async def _wait_next_records(
self, timeout: float
) -> Tuple[Optional[RecordMap], Optional[Set[TP]]]:
Expand All @@ -750,10 +760,18 @@ async def _wait_next_records(
# Fetch records only if active partitions to avoid the risk of
# fetching all partitions in the beginning when none of the
# partitions is paused/resumed.
records = await self._getmany(
_getmany = self._getmany(
active_partitions=active_partitions,
timeout=timeout,
)
wait_results = await self.wait_first(
_getmany,
self.suspend_flow.wait(),
)
for coro, result in zip(wait_results.done, wait_results.results):
if coro is _getmany:
records = result
break
else:
# We should still release to the event loop
await self.sleep(1)
Expand Down

0 comments on commit ef3c6e1

Please sign in to comment.