Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 150 additions & 9 deletions quixstreams/state/recovery.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import time
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

from confluent_kafka import OFFSET_BEGINNING
from confluent_kafka import TopicPartition as ConfluentPartition
Expand Down Expand Up @@ -62,6 +62,8 @@ def __init__(
self._committed_offsets = committed_offsets
self._recovery_consume_position: Optional[int] = None
self._initial_offset: Optional[int] = None
self._invalid_offset_count = 0 # Track consecutive invalid offset attempts
self._last_valid_position_time: Optional[float] = None

def __repr__(self):
return (
Expand Down Expand Up @@ -128,6 +130,35 @@ def recovery_consume_position(self) -> Optional[int]:
def had_recovery_changes(self) -> bool:
return self._initial_offset != self.offset

def increment_invalid_offset_count(self) -> int:
"""
Increment the counter for consecutive invalid offset attempts.
Returns the new count.
"""
self._invalid_offset_count += 1
return self._invalid_offset_count

def reset_invalid_offset_count(self):
"""
Reset the invalid offset counter when a valid position is obtained.
"""
self._invalid_offset_count = 0
self._last_valid_position_time = time.monotonic()

@property
def invalid_offset_count(self) -> int:
"""
Get the number of consecutive invalid offset attempts.
"""
return self._invalid_offset_count

@property
def last_valid_position_time(self) -> Optional[float]:
"""
Get the time when a valid position was last obtained.
"""
return self._last_valid_position_time

def recover_from_changelog_message(
self, changelog_message: SuccessfulConfluentKafkaMessageProto
):
Expand Down Expand Up @@ -315,12 +346,18 @@ class RecoveryManager:
Recovery is attempted from the `Application` after any new partition assignment.
"""

# Maximum number of consecutive invalid offset attempts before failing loudly
# At 10-second progress logging intervals, 60 attempts = ~10 minutes
MAX_INVALID_OFFSET_ATTEMPTS = 60

def __init__(self, consumer: BaseConsumer, topic_manager: TopicManager):
self._running = False
self._consumer = consumer
self._topic_manager = topic_manager
self._recovery_partitions: Dict[int, Dict[str, RecoveryPartition]] = {}
self._last_progress_logged_time = time.monotonic()
# Cache position results to avoid double calls in same iteration
self._position_cache: Dict[str, Tuple[float, ConfluentPartition]] = {}

@property
def partitions(self) -> Dict[int, Dict[str, RecoveryPartition]]:
Expand Down Expand Up @@ -515,6 +552,9 @@ def _revoke_recovery_partitions(self, recovery_partitions: List[RecoveryPartitio
)
for rp in recovery_partitions:
del self._recovery_partitions[rp.partition_num][rp.changelog_name]
# Clean up position cache for revoked partition
cache_key = f"{rp.changelog_name}:{rp.partition_num}"
self._position_cache.pop(cache_key, None)
for partition_num in partition_nums:
if not self._recovery_partitions[partition_num]:
del self._recovery_partitions[partition_num]
Expand All @@ -536,6 +576,14 @@ def _update_recovery_status(self):
rp_revokes = []
for rp in dict_values(self._recovery_partitions):
position = self._get_changelog_offset(rp)
if position is None:
# Skip status update if position is not yet valid (e.g., during rebalancing)
# Will retry on next poll cycle
logger.debug(
f"Skipping recovery status update for {rp}: position not available"
)
continue

rp.set_recovery_consume_position(position)
if rp.finished_recovery_check:
rp_revokes.append(rp)
Expand All @@ -561,25 +609,118 @@ def _recovery_loop(self) -> None:
rp = self._recovery_partitions[msg.partition()][msg.topic()]
rp.recover_from_changelog_message(changelog_message=msg)

def _get_position_with_cache(self, rp: RecoveryPartition) -> ConfluentPartition:
"""
Get the consumer position for a RecoveryPartition, using cache to avoid
multiple calls in the same iteration.

:param rp: RecoveryPartition to get position for
:return: ConfluentPartition with offset and error information
"""
cache_key = f"{rp.changelog_name}:{rp.partition_num}"
current_time = time.monotonic()

# Check if we have a fresh cached value (within last second)
if cache_key in self._position_cache:
cached_time, cached_position = self._position_cache[cache_key]
if current_time - cached_time < 1.0:
return cached_position

# Query position and cache it
position_tp = self._consumer.position(
[ConfluentPartition(rp.changelog_name, rp.partition_num)]
)[0]
self._position_cache[cache_key] = (current_time, position_tp)
return position_tp

def _log_recovery_progress(self) -> None:
"""
Periodically log the recovery progress of all RecoveryPartitions.
"""
if self._last_progress_logged_time < time.monotonic() - 10:
for rp in dict_values(self._recovery_partitions):
last_consumed_offset = self._get_changelog_offset(rp) - 1
logger.info(
f"Recovery progress for {rp}: {last_consumed_offset} / {rp.changelog_highwater}"
)
# Use cached position to avoid redundant network calls
position_tp = self._get_position_with_cache(rp)

if position_tp.error:
count = rp.invalid_offset_count
log_level = logger.warning if count > 30 else logger.info
log_level(
f"Recovery progress for {rp}: position unavailable "
f"(error: {position_tp.error}, attempts: {count})"
)
elif position_tp.offset < 0:
count = rp.invalid_offset_count
log_level = logger.warning if count > 30 else logger.info
log_level(
f"Recovery progress for {rp}: position not yet established "
f"(offset: {position_tp.offset}, attempts: {count})"
)
else:
last_consumed_offset = position_tp.offset - 1
logger.info(
f"Recovery progress for {rp}: {last_consumed_offset} / {rp.changelog_highwater}"
)
self._last_progress_logged_time = time.monotonic()

def _get_changelog_offset(self, rp: RecoveryPartition) -> int:
def _get_changelog_offset(self, rp: RecoveryPartition) -> Optional[int]:
"""
Get the current offset of the changelog partition.

Returns None if the position is not yet established (e.g., during rebalancing)
or if there's an error querying the position.

Tracks consecutive invalid offset attempts and raises an exception if the
threshold is exceeded.

:return: The current offset, or None if position is invalid/unavailable
:raises RuntimeError: If position remains invalid beyond MAX_INVALID_OFFSET_ATTEMPTS
"""
return self._consumer.position(
[ConfluentPartition(rp.changelog_name, rp.partition_num)]
)[0].offset
# Use cached position to avoid redundant network calls
position_tp = self._get_position_with_cache(rp)

# Check for Kafka errors (e.g., during rebalancing)
if position_tp.error:
count = rp.increment_invalid_offset_count()
logger.debug(
f"Cannot get position for {rp} due to Kafka error: {position_tp.error}. "
f"This is expected during rebalancing (attempt {count}/{self.MAX_INVALID_OFFSET_ATTEMPTS})."
)
self._check_invalid_offset_threshold(rp, f"error: {position_tp.error}")
return None

# Check for special Kafka offset values (OFFSET_INVALID=-1001, OFFSET_STORED=-1000, etc.)
offset = position_tp.offset
if offset < 0:
count = rp.increment_invalid_offset_count()
logger.debug(
f"Position not yet established for {rp}: offset={offset}. "
f"This is expected during rebalancing (attempt {count}/{self.MAX_INVALID_OFFSET_ATTEMPTS})."
)
self._check_invalid_offset_threshold(rp, f"offset={offset}")
return None

# Valid offset obtained - reset the counter
rp.reset_invalid_offset_count()
return offset

def _check_invalid_offset_threshold(self, rp: RecoveryPartition, reason: str):
"""
Check if the invalid offset count exceeds the threshold and fail loudly if so.

:param rp: RecoveryPartition being checked
:param reason: Description of why the offset is invalid
:raises RuntimeError: If threshold is exceeded
"""
if rp.invalid_offset_count > self.MAX_INVALID_OFFSET_ATTEMPTS:
error_msg = (
f"Recovery stuck for {rp}: position has been invalid for "
f"{rp.invalid_offset_count} consecutive attempts ({reason}). "
f"This indicates a serious issue with the Kafka consumer or broker. "
f"Last valid position was at {rp.last_valid_position_time or 'never'}."
)
logger.error(error_msg)
raise RuntimeError(error_msg)

def stop_recovery(self):
self._running = False
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest.mock import MagicMock, patch

import pytest
from confluent_kafka import TopicPartition
from confluent_kafka import OFFSET_INVALID, TopicPartition
from confluent_kafka import TopicPartition as ConfluentPartition

from quixstreams.kafka import Consumer
Expand Down Expand Up @@ -350,6 +350,115 @@ def test_do_recovery_no_partitions_assigned(self, recovery_manager_factory):
# Check that consumer.poll() is not called
assert not consumer.poll.called

def test_do_recovery_handles_invalid_offset_during_rebalance(
self,
recovery_manager_factory,
topic_manager_factory,
):
"""
Test that RecoveryManager handles OFFSET_INVALID gracefully when
consumer.position() returns invalid offset during rebalancing.

This reproduces GitHub issue #1067 where recovery gets stuck in infinite
loop when partition stays assigned through rebalance but position becomes
temporarily invalid.
"""
topic_name = str(uuid.uuid4())
store_name = "default"
lowwater, highwater = 0, 10

# Setup topics
topic_manager = topic_manager_factory()
data_topic = topic_manager.topic(topic_name)
changelog_topic = topic_manager.changelog_topic(
stream_id=topic_name,
store_name=store_name,
config=data_topic.broker_config,
)

data_tp = TopicPartition(topic=data_topic.name, partition=0)
changelog_tp = TopicPartition(topic=changelog_topic.name, partition=0)
assignment = [data_tp, changelog_tp]

# Create changelog message for recovery
# Message at offset (highwater - 1) means after consuming it,
# position will be at highwater, completing recovery
changelog_message = ConfluentKafkaMessageStub(
topic=changelog_topic.name,
partition=0,
offset=highwater - 1,
key=b"key",
value=b"value",
headers=[(CHANGELOG_CF_MESSAGE_HEADER, b"default")],
)

# Create mocked consumer
consumer = MagicMock(spec_set=Consumer)
consumer.assignment.return_value = assignment

# Simulate rebalancing scenario:
# 1. First poll returns None → OFFSET_INVALID detected and skipped
# 2. Second poll returns message → processes it
# 3. Third poll returns None → position check shows recovery complete
consumer.poll.side_effect = [None, changelog_message, None]

# Simulate consumer.position() behavior during rebalance:
# 1. First call returns OFFSET_INVALID (mid-rebalance - gets skipped)
# 2. Subsequent calls return highwater (recovery complete after message)
position_call_count = 0

def position_side_effect(partitions):
nonlocal position_call_count
position_call_count += 1
if position_call_count == 1:
# Mid-rebalance: return OFFSET_INVALID (will be skipped by fix)
return [ConfluentPartition(changelog_topic.name, 0, OFFSET_INVALID)]
else:
# After OFFSET_INVALID resolved, position is at highwater
return [ConfluentPartition(changelog_topic.name, 0, highwater)]

consumer.position.side_effect = position_side_effect

# Mock store partition
store_partition = MagicMock(spec_set=StorePartition)
# Stored offset is one before the message we'll recover
store_partition.get_changelog_offset.return_value = highwater - 2

# Setup recovery manager
recovery_manager = recovery_manager_factory(
consumer=consumer, topic_manager=topic_manager
)

consumer.get_watermark_offsets.return_value = (lowwater, highwater)
recovery_manager.assign_partition(
topic=topic_name,
partition=0,
committed_offsets={topic_name: -1001},
store_partitions={store_name: store_partition},
)

# Trigger recovery - should complete successfully despite OFFSET_INVALID
recovery_manager.do_recovery()

# Verify recovery completed successfully
assert (
not recovery_manager.partitions
), "Recovery should complete and unassign all partitions"
assert consumer.poll.call_count == 3, (
"Should poll three times: "
"1) None (OFFSET_INVALID detected and skipped), "
"2) message consumed, "
"3) None (position==highwater, completes recovery)"
)
assert position_call_count == 2, (
"Should call position twice: "
"first returns OFFSET_INVALID (skipped), "
"second returns highwater (recovery complete)"
)

# Verify the changelog message was processed
store_partition.recover_from_changelog_message.assert_called_once()


@pytest.mark.parametrize("store_type", SUPPORTED_STORES, indirect=True)
class TestRecoveryManagerRecover:
Expand Down
Loading