diff --git a/faust/_cython/streams.pyx b/faust/_cython/streams.pyx index 13847be8f..9043976d0 100644 --- a/faust/_cython/streams.pyx +++ b/faust/_cython/streams.pyx @@ -89,6 +89,8 @@ cdef class StreamIterator: if need_slow_get: channel_value = await self.chan_slow_get() event, value, sensor_state = self._prepare_event(channel_value) + if value is self._skipped_value: + return value, sensor_state try: for processor in self.processors: @@ -158,6 +160,16 @@ cdef class StreamIterator: offset = message.offset consumer = self.consumer + if message.generation_id != self.app.consumer_generation_id: + self.app.log.dev( + "Skipping message %r with generation_id %r because " + "app generation_id is %r", + message, + message.generation_id, + self.app.consumer_generation_id, + ) + return None, self._skipped_value, stream_state + if topic in self.acking_topics and not message.tracked: message.tracked = True self.add_unacked(message) diff --git a/faust/agents/agent.py b/faust/agents/agent.py index 1328e4486..d4b6e9e21 100644 --- a/faust/agents/agent.py +++ b/faust/agents/agent.py @@ -1184,6 +1184,7 @@ def to_message( checksum=b"", serialized_key_size=0, serialized_value_size=0, + generation_id=self.app.consumer_generation_id, ) async def throw(self, exc: BaseException) -> None: diff --git a/faust/channels.py b/faust/channels.py index ed39c90d2..937f6d86b 100644 --- a/faust/channels.py +++ b/faust/channels.py @@ -288,6 +288,7 @@ def as_future_message( # [ask] topic=None, offset=None, + generation_id=self.app.consumer_generation_id, ), ) diff --git a/faust/streams.py b/faust/streams.py index 9ecf2b77f..865888ab5 100644 --- a/faust/streams.py +++ b/faust/streams.py @@ -974,6 +974,16 @@ async def _py_aiter(self) -> AsyncIterator[T_co]: tp = message.tp offset = message.offset + if message.generation_id != self.app.consumer_generation_id: + value = skipped_value + self.log.dev( + "Skipping message %r with generation_id %r because " + "app generation_id is %r", + message, + message.generation_id, + self.app.consumer_generation_id, + ) + break if topic in acking_topics and not message.tracked: message.tracked = True # This inlines Consumer.track_message(message) diff --git a/faust/transport/consumer.py b/faust/transport/consumer.py index 4452d85ff..6d1f12a5b 100644 --- a/faust/transport/consumer.py +++ b/faust/transport/consumer.py @@ -435,6 +435,7 @@ class Consumer(Service, ConsumerT): flow_active: bool = True can_resume_flow: Event + suspend_flow: Event def __init__( self, @@ -475,6 +476,7 @@ def __init__( self._end_offset_monitor_interval = self.commit_interval * 2 self.randomly_assigned_topics = set() self.can_resume_flow = Event() + self.suspend_flow = Event() self._reset_state() super().__init__(loop=loop or self.transport.loop, **kwargs) self.transactions = self.transport.create_transaction_manager( @@ -497,6 +499,7 @@ def _reset_state(self) -> None: self._paused_partitions = set() self._buffered_partitions = set() self.can_resume_flow.clear() + self.suspend_flow.clear() self.flow_active = True self._time_start = monotonic() @@ -573,11 +576,13 @@ def stop_flow(self) -> None: """Block consumer from processing any more messages.""" self.flow_active = False self.can_resume_flow.clear() + self.suspend_flow.set() def resume_flow(self) -> None: """Allow consumer to process messages.""" self.flow_active = True self.can_resume_flow.set() + self.suspend_flow.clear() def pause_partitions(self, tps: Iterable[TP]) -> None: """Pause fetching from partitions.""" @@ -1120,7 +1125,9 @@ async def _drain_messages(self, fetcher: ServiceT) -> None: # pragma: no cover if self._n_acked >= commit_every: self._n_acked = 0 await self.commit() - await callback(message) + await self.wait_first( + callback(message), self.suspend_flow.wait() + ) set_read_offset(tp, offset) else: self.log.dev( diff --git a/faust/transport/drivers/aiokafka.py b/faust/transport/drivers/aiokafka.py index f2f244db9..c15751c55 100644 --- a/faust/transport/drivers/aiokafka.py +++ b/faust/transport/drivers/aiokafka.py @@ -267,6 +267,7 @@ def _to_message(self, tp: TP, record: Any) -> ConsumerMessage: record.serialized_key_size, record.serialized_value_size, tp, + generation_id=self.app.consumer_generation_id, ) async def on_stop(self) -> None: diff --git a/faust/types/tuples.py b/faust/types/tuples.py index 226648737..aabe6c217 100644 --- a/faust/types/tuples.py +++ b/faust/types/tuples.py @@ -69,6 +69,7 @@ class PendingMessage(NamedTuple): callback: Optional[MessageSentCallback] topic: Optional[str] = None offset: Optional[int] = None + generation_id: Optional[int] = None def _PendingMessage_to_Message(p: PendingMessage) -> "Message": @@ -92,6 +93,7 @@ def _PendingMessage_to_Message(p: PendingMessage) -> "Message": value=p.value, checksum=None, tp=tp, + generation_id=p.generation_id, ) @@ -133,6 +135,7 @@ class Message: "tracked", "span", "__weakref__", + "generation_id", ) use_tracking: bool = False @@ -154,6 +157,7 @@ def __init__( time_in: float = None, time_out: float = None, time_total: float = None, + generation_id: int = None, ) -> None: self.topic: str = topic self.partition: int = partition @@ -183,6 +187,12 @@ def __init__( #: still processing. self.time_total: Optional[float] = time_total + # In some edge cases a message can slip through to the stream from before a + # rebalance occured if it gets stuck in the conductor or somewhere else. We + # track the generation_id when the message is fetched so we can discard if + # needed. + self.generation_id: Optional[int] = generation_id + def ack(self, consumer: _ConsumerT, n: int = 1) -> bool: if not self.acked: # if no more references, mark offset as safe-to-commit in diff --git a/tests/functional/agents/helpers.py b/tests/functional/agents/helpers.py index 81aae65a1..1e7ab6efa 100644 --- a/tests/functional/agents/helpers.py +++ b/tests/functional/agents/helpers.py @@ -191,6 +191,7 @@ def Message( value=value, checksum=checksum, tp=tp, + generation_id=self.app.consumer_generation_id, ) def next_offset(self, tp: TP, *, offsets=CURRENT_OFFSETS) -> int: diff --git a/tests/helpers.py b/tests/helpers.py index 747881652..260194bba 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -15,7 +15,8 @@ def message( timestamp=None, headers=None, offset=1, - checksum=None + checksum=None, + generation_id=0, ): return Message( key=key, @@ -27,6 +28,7 @@ def message( timestamp_type=1 if timestamp else 0, headers=headers, checksum=checksum, + generation_id=generation_id, )