From ef3c6e10b596a06a6a8fa27ceaf1992ced359ca8 Mon Sep 17 00:00:00 2001
From: ekerstens <49325583+ekerstens@users.noreply.github.com>
Date: Tue, 14 Dec 2021 14:48:23 -0800
Subject: [PATCH] set active_tps in place (#239)

* set active_tps in place

* Cancel getmany during rebalance

* Fix nonetype error

* Check for nonetype

* Fix wait_results coro check

* change == to is

Co-authored-by: Eric Kerstens
---
 faust/transport/consumer.py | 30 ++++++++++++++++++++++++------
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/faust/transport/consumer.py b/faust/transport/consumer.py
index 401e35bef..c67244139 100644
--- a/faust/transport/consumer.py
+++ b/faust/transport/consumer.py
@@ -429,7 +429,7 @@ class Consumer(Service, ConsumerT):
     _commit_every: Optional[int]
     _n_acked: int = 0
 
-    _active_partitions: Optional[Set[TP]]
+    _active_partitions: Set[TP]
     _paused_partitions: Set[TP]
     _buffered_partitions: Set[TP]
 
@@ -495,7 +495,7 @@ def on_init_dependencies(self) -> Iterable[ServiceT]:
         return []
 
     def _reset_state(self) -> None:
-        self._active_partitions = None
+        self._active_partitions = set()
         self._paused_partitions = set()
         self._buffered_partitions = set()
         self.can_resume_flow.clear()
@@ -516,9 +516,12 @@ def _get_active_partitions(self) -> Set[TP]:
         return tps
 
     def _set_active_tps(self, tps: Set[TP]) -> Set[TP]:
-        xtps = self._active_partitions = ensure_TPset(tps)  # copy
-        xtps.difference_update(self._paused_partitions)
-        return xtps
+        if self._active_partitions is None:
+            self._active_partitions = set()
+        self._active_partitions.clear()
+        self._active_partitions.update(ensure_TPset(tps))
+        self._active_partitions.difference_update(self._paused_partitions)
+        return self._active_partitions
 
     def on_buffer_full(self, tp: TP) -> None:
         # do not remove the partition when in recovery
@@ -730,6 +733,13 @@ async def getmany(self, timeout: float) -> AsyncIterator[Tuple[TP, Message]]:
                 # convert timestamp to seconds from int milliseconds.
                 yield tp, to_message(tp, record)
 
+    async def _wait_suspend(self):
+        """Wrapper around self.suspend_flow.wait() with no return value.
+
+        This allows for easily checking which coroutine finished first.
+        """
+        await self.suspend_flow.wait()
+
     async def _wait_next_records(
         self, timeout: float
     ) -> Tuple[Optional[RecordMap], Optional[Set[TP]]]:
@@ -750,10 +760,18 @@ async def _wait_next_records(
             # Fetch records only if active partitions to avoid the risk of
             # fetching all partitions in the beginning when none of the
            # partitions is paused/resumed.
-            records = await self._getmany(
+            _getmany = self._getmany(
                 active_partitions=active_partitions,
                 timeout=timeout,
             )
+            wait_results = await self.wait_first(
+                _getmany,
+                self.suspend_flow.wait(),
+            )
+            for coro, result in zip(wait_results.done, wait_results.results):
+                if coro is _getmany:
+                    records = result
+                    break
         else:
             # We should still release to the event loop
             await self.sleep(1)
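
For context (not part of the patch): the key behavioural change in
_wait_next_records is that the broker fetch is now raced against the
suspend_flow event, so a rebalance can interrupt an in-flight getmany instead
of waiting for it to return. Below is a minimal, self-contained sketch of that
pattern using plain asyncio; fetch_records, wait_next_records and the
Event-based rebalance signal are hypothetical stand-ins for _getmany,
_wait_next_records and suspend_flow, and asyncio.wait(...,
return_when=FIRST_COMPLETED) stands in for the wait_first call used in the
patch.

    import asyncio
    from typing import List, Optional


    async def fetch_records(timeout: float) -> List[str]:
        # Hypothetical stand-in for Consumer._getmany(): a slow broker fetch.
        await asyncio.sleep(timeout)
        return ["record-1", "record-2"]


    async def wait_next_records(
        suspend_flow: asyncio.Event, timeout: float
    ) -> Optional[List[str]]:
        # Race the fetch against the suspend signal, mirroring the patch's
        # wait_first(_getmany, self.suspend_flow.wait()) call.
        fetch = asyncio.ensure_future(fetch_records(timeout))
        suspend = asyncio.ensure_future(suspend_flow.wait())
        done, _pending = await asyncio.wait(
            {fetch, suspend}, return_when=asyncio.FIRST_COMPLETED
        )
        if fetch in done:
            suspend.cancel()
            return fetch.result()
        # The rebalance signal won the race: cancel the in-flight fetch.
        fetch.cancel()
        return None


    async def main() -> None:
        suspend_flow = asyncio.Event()
        # Signal a "rebalance" after 0.1s, long before the 5s fetch finishes.
        asyncio.get_running_loop().call_later(0.1, suspend_flow.set)
        records = await wait_next_records(suspend_flow, timeout=5.0)
        print("records:", records)  # prints "records: None": fetch cancelled


    asyncio.run(main())

The other half of the patch makes _set_active_tps mutate _active_partitions
with clear()/update() instead of rebinding the attribute to a fresh copy,
presumably so that any code already holding a reference to that set, such as
an in-flight fetch, keeps seeing the current active partitions after a
rebalance rather than a stale snapshot.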