Skip to content

fix(concurrent-perpartition-cursor): Fix memory issues #568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections import OrderedDict
from copy import deepcopy
from datetime import timedelta
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional

from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
Expand Down Expand Up @@ -66,8 +66,8 @@ class ConcurrentPerPartitionCursor(Cursor):
_GLOBAL_STATE_KEY = "state"
_PERPARTITION_STATE_KEY = "states"
_IS_PARTITION_DUPLICATION_LOGGED = False
_KEY = 0
_VALUE = 1
_PARENT_STATE = 0
_GENERATION_SEQUENCE = 1

def __init__(
self,
Expand Down Expand Up @@ -99,19 +99,29 @@ def __init__(
self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict()

# Parent-state tracking: store each partition’s parent state in creation order
self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict()
self._partition_parent_state_map: OrderedDict[str, tuple[Mapping[str, Any], int]] = (
OrderedDict()
)
self._parent_state: Optional[StreamState] = None

# Tracks when the last slice for partition is emitted
self._partitions_done_generating_stream_slices: set[str] = set()
# Used to track the index of partitions that are not closed yet
self._processing_partitions_indexes: List[int] = list()
self._generated_partitions_count: int = 0
# Dictionary to map partition keys to their index
self._partition_key_to_index: dict[str, int] = {}

self._finished_partitions: set[str] = set()
self._lock = threading.Lock()
self._timer = Timer()
self._new_global_cursor: Optional[StreamState] = None
self._lookback_window: int = 0
self._parent_state: Optional[StreamState] = None
self._new_global_cursor: Optional[StreamState] = None
self._number_of_partitions: int = 0
self._use_global_cursor: bool = use_global_cursor
self._partition_serializer = PerPartitionKeySerializer()

# Track the last time a state message was emitted
self._last_emission_time: float = 0.0
self._timer = Timer()

self._set_initial_state(stream_state)

Expand Down Expand Up @@ -157,60 +167,37 @@ def close_partition(self, partition: Partition) -> None:
self._cursor_per_partition[partition_key].close_partition(partition=partition)
cursor = self._cursor_per_partition[partition_key]
if (
partition_key in self._finished_partitions
partition_key in self._partitions_done_generating_stream_slices
and self._semaphore_per_partition[partition_key]._value == 0
):
self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])

# Clean up the partition if it is fully processed
self._cleanup_if_done(partition_key)

self._check_and_update_parent_state()

self._emit_state_message()

def _check_and_update_parent_state(self) -> None:
"""
Pop the leftmost partition state from _partition_parent_state_map only if
*all partitions* up to (and including) that partition key in _semaphore_per_partition
are fully finished (i.e. in _finished_partitions and semaphore._value == 0).
Additionally, delete finished semaphores with a value of 0 to free up memory,
as they are only needed to track errors and completion status.
"""
last_closed_state = None

while self._partition_parent_state_map:
# Look at the earliest partition key in creation order
earliest_key = next(iter(self._partition_parent_state_map))

# Verify ALL partitions from the left up to earliest_key are finished
all_left_finished = True
for p_key, sem in list(
self._semaphore_per_partition.items()
): # Use list to allow modification during iteration
# If any earlier partition is still not finished, we must stop
if p_key not in self._finished_partitions or sem._value != 0:
all_left_finished = False
break
# Once we've reached earliest_key in the semaphore order, we can stop checking
if p_key == earliest_key:
break

# If the partitions up to earliest_key are not all finished, break the while-loop
if not all_left_finished:
break
earliest_key, (candidate_state, candidate_seq) = next(
iter(self._partition_parent_state_map.items())
)

# Pop the leftmost entry from parent-state map
_, closed_parent_state = self._partition_parent_state_map.popitem(last=False)
last_closed_state = closed_parent_state
# if any partition that started <= candidate_seq is still open, we must wait
if (
self._processing_partitions_indexes
and self._processing_partitions_indexes[0] <= candidate_seq
):
break

# Clean up finished semaphores with value 0 up to and including earliest_key
for p_key in list(self._semaphore_per_partition.keys()):
sem = self._semaphore_per_partition[p_key]
if p_key in self._finished_partitions and sem._value == 0:
del self._semaphore_per_partition[p_key]
logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0")
if p_key == earliest_key:
break
# safe to pop
self._partition_parent_state_map.popitem(last=False)
last_closed_state = candidate_state

# Update _parent_state if we popped at least one partition
if last_closed_state is not None:
self._parent_state = last_closed_state

Expand Down Expand Up @@ -289,26 +276,32 @@ def _generate_slices_from_partition(
if not self._IS_PARTITION_DUPLICATION_LOGGED:
logger.warning(f"Partition duplication detected for stream {self._stream_name}")
self._IS_PARTITION_DUPLICATION_LOGGED = True
return
else:
self._semaphore_per_partition[partition_key] = threading.Semaphore(0)

with self._lock:
seq = self._generated_partitions_count
self._generated_partitions_count += 1
self._processing_partitions_indexes.append(seq)
self._partition_key_to_index[partition_key] = seq

if (
len(self._partition_parent_state_map) == 0
or self._partition_parent_state_map[
next(reversed(self._partition_parent_state_map))
]
][self._PARENT_STATE]
!= parent_state
):
self._partition_parent_state_map[partition_key] = deepcopy(parent_state)
self._partition_parent_state_map[partition_key] = (deepcopy(parent_state), seq)

for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
cursor.stream_slices(),
lambda: None,
):
self._semaphore_per_partition[partition_key].release()
if is_last_slice:
self._finished_partitions.add(partition_key)
self._partitions_done_generating_stream_slices.add(partition_key)
yield StreamSlice(
partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
)
Expand Down Expand Up @@ -338,14 +331,11 @@ def _ensure_partition_limit(self) -> None:
while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1:
# Try removing finished partitions first
for partition_key in list(self._cursor_per_partition.keys()):
if partition_key in self._finished_partitions and (
partition_key not in self._semaphore_per_partition
or self._semaphore_per_partition[partition_key]._value == 0
):
if partition_key not in self._partition_key_to_index:
oldest_partition = self._cursor_per_partition.pop(
partition_key
) # Remove the oldest partition
logger.warning(
logger.debug(
f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}."
)
break
Expand Down Expand Up @@ -474,6 +464,25 @@ def _update_global_cursor(self, value: Any) -> None:
):
self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)}

def _cleanup_if_done(self, partition_key: str) -> None:
"""
Free every in-memory structure that belonged to a completed partition:
cursor, semaphore, flag inside `_finished_partitions`
"""
if not (
partition_key in self._partitions_done_generating_stream_slices
and self._semaphore_per_partition[partition_key]._value == 0
):
return

self._semaphore_per_partition.pop(partition_key, None)
self._partitions_done_generating_stream_slices.discard(partition_key)

seq = self._partition_key_to_index.pop(partition_key)
self._processing_partitions_indexes.remove(seq)

logger.debug(f"Partition {partition_key} fully processed and cleaned up.")

def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
return self._partition_serializer.to_partition_key(partition)

Expand Down
Loading
Loading