Skip to content

chore(refactor): Remove Partition.close #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def on_partition_complete_sentinel(

try:
if sentinel.is_successful:
partition.close()
stream = self._stream_name_to_instance[partition.stream_name()]
stream.cursor.close_partition(partition)
except Exception as exception:
self._flag_exception(partition.stream_name(), exception)
yield AirbyteTracedException.from_exception(
Expand Down
11 changes: 0 additions & 11 deletions airbyte_cdk/sources/file_based/stream/concurrent/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,13 @@ def __init__(
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: "AbstractConcurrentFileBasedCursor",
):
self._stream = stream
self._slice = _slice
self._message_repository = message_repository
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._cursor = cursor
self._is_closed = False

def read(self) -> Iterable[Record]:
try:
Expand Down Expand Up @@ -289,13 +286,6 @@ def to_slice(self) -> Optional[Mapping[str, Any]]:
file = self._slice["files"][0]
return {"files": [file]}

def close(self) -> None:
self._cursor.close_partition(self)
self._is_closed = True

def is_closed(self) -> bool:
return self._is_closed

def __hash__(self) -> int:
if self._slice:
# Convert the slice to a string so that it can be hashed
Expand Down Expand Up @@ -352,7 +342,6 @@ def generate(self) -> Iterable[FileBasedStreamPartition]:
self._sync_mode,
self._cursor_field,
self._state,
self._cursor,
)
)
self._cursor.set_pending_partitions(pending_partitions)
Expand Down
15 changes: 0 additions & 15 deletions airbyte_cdk/sources/streams/concurrent/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def create_from_stream(
else SyncMode.incremental,
[cursor_field] if cursor_field is not None else None,
state,
cursor,
),
name=stream.name,
namespace=stream.namespace,
Expand Down Expand Up @@ -259,7 +258,6 @@ def __init__(
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: Cursor,
):
"""
:param stream: The stream to delegate to
Expand All @@ -272,8 +270,6 @@ def __init__(
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._cursor = cursor
self._is_closed = False

def read(self) -> Iterable[Record]:
"""
Expand Down Expand Up @@ -323,13 +319,6 @@ def __hash__(self) -> int:
def stream_name(self) -> str:
return self._stream.name

def close(self) -> None:
self._cursor.close_partition(self)
self._is_closed = True

def is_closed(self) -> bool:
return self._is_closed

def __repr__(self) -> str:
return f"StreamPartition({self._stream.name}, {self._slice})"

Expand All @@ -349,7 +338,6 @@ def __init__(
sync_mode: SyncMode,
cursor_field: Optional[List[str]],
state: Optional[MutableMapping[str, Any]],
cursor: Cursor,
):
"""
:param stream: The stream to delegate to
Expand All @@ -360,7 +348,6 @@ def __init__(
self._sync_mode = sync_mode
self._cursor_field = cursor_field
self._state = state
self._cursor = cursor

def generate(self) -> Iterable[Partition]:
for s in self._stream.stream_slices(
Expand All @@ -373,7 +360,6 @@ def generate(self) -> Iterable[Partition]:
self._sync_mode,
self._cursor_field,
self._state,
self._cursor,
)


Expand Down Expand Up @@ -451,7 +437,6 @@ def generate(self) -> Iterable[Partition]:
self._sync_mode,
self._cursor_field,
self._state,
self._cursor,
)


Expand Down
15 changes: 0 additions & 15 deletions airbyte_cdk/sources/streams/concurrent/partitions/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,6 @@ def stream_name(self) -> str:
"""
pass

@abstractmethod
def close(self) -> None:
"""
Closes the partition.
"""
pass

@abstractmethod
def is_closed(self) -> bool:
"""
Returns whether the partition is closed.
:return:
"""
pass

@abstractmethod
def __hash__(self) -> int:
"""
Expand Down
15 changes: 12 additions & 3 deletions unit_tests/sources/file_based/stream/concurrent/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,12 @@ def test_file_based_stream_partition(transformer, expected_records):
cursor_field = None
state = None
partition = FileBasedStreamPartition(
stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR
stream,
_slice,
message_repository,
sync_mode,
cursor_field,
state,
)

a_log_message = AirbyteMessage(
Expand Down Expand Up @@ -168,7 +173,6 @@ def test_file_based_stream_partition_raising_exception(exception_type, expected_
_ANY_SYNC_MODE,
_ANY_CURSOR_FIELD,
_ANY_STATE,
_ANY_CURSOR,
)

stream.read_records.side_effect = Exception()
Expand Down Expand Up @@ -204,7 +208,12 @@ def test_file_based_stream_partition_hash(_slice, expected_hash):
stream = Mock()
stream.name = "stream"
partition = FileBasedStreamPartition(
stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR
stream,
_slice,
Mock(),
_ANY_SYNC_MODE,
_ANY_CURSOR_FIELD,
_ANY_STATE,
)

_hash = partition.__hash__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def test_add_file(
SyncMode.full_refresh,
FileBasedConcurrentCursor.CURSOR_FIELD,
initial_state,
cursor,
)
for uri, timestamp in pending_files
]
Expand Down
9 changes: 3 additions & 6 deletions unit_tests/sources/streams/concurrent/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_stream_partition_generator(sync_mode):
stream.stream_slices.return_value = stream_slices

partition_generator = StreamPartitionGenerator(
stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR
stream, message_repository, _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE
)

partitions = list(partition_generator.generate())
Expand Down Expand Up @@ -115,9 +115,7 @@ def test_stream_partition(transformer, expected_records):
sync_mode = SyncMode.full_refresh
cursor_field = None
state = None
partition = StreamPartition(
stream, _slice, message_repository, sync_mode, cursor_field, state, _ANY_CURSOR
)
partition = StreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state)

a_log_message = AirbyteMessage(
type=MessageType.LOG,
Expand Down Expand Up @@ -162,7 +160,6 @@ def test_stream_partition_raising_exception(exception_type, expected_display_mes
_ANY_SYNC_MODE,
_ANY_CURSOR_FIELD,
_ANY_STATE,
_ANY_CURSOR,
)

stream.read_records.side_effect = Exception()
Expand All @@ -188,7 +185,7 @@ def test_stream_partition_hash(_slice, expected_hash):
stream = Mock()
stream.name = "stream"
partition = StreamPartition(
stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE, _ANY_CURSOR
stream, _slice, Mock(), _ANY_SYNC_MODE, _ANY_CURSOR_FIELD, _ANY_STATE
)

_hash = partition.__hash__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_handle_on_partition_complete_sentinel_with_messages_from_repository(sel
]
assert messages == expected_messages

partition.close.assert_called_once()
self._stream.cursor.close_partition.assert_called_once()

@freezegun.freeze_time("2020-01-01T00:00:00")
def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stream_is_done(
Expand Down Expand Up @@ -298,14 +298,14 @@ def test_handle_on_partition_complete_sentinel_yields_status_message_if_the_stre
)
]
assert messages == expected_messages
self._a_closed_partition.close.assert_called_once()
self._another_stream.cursor.close_partition.assert_called_once()

@freezegun.freeze_time("2020-01-01T00:00:00")
def test_given_exception_on_partition_complete_sentinel_then_yield_error_trace_message_and_stream_is_incomplete(
self,
) -> None:
self._a_closed_partition.stream_name.return_value = self._stream.name
self._a_closed_partition.close.side_effect = ValueError
self._stream.cursor.close_partition.side_effect = ValueError

handler = ConcurrentReadProcessor(
[self._stream],
Expand Down Expand Up @@ -375,7 +375,7 @@ def test_handle_on_partition_complete_sentinel_yields_no_status_message_if_the_s

expected_messages = []
assert messages == expected_messages
partition.close.assert_called_once()
self._stream.cursor.close_partition.assert_called_once()

@freezegun.freeze_time("2020-01-01T00:00:00")
def test_on_record_no_status_message_no_repository_messge(self):
Expand Down Expand Up @@ -733,7 +733,7 @@ def test_given_partition_completion_is_not_success_then_do_not_close_partition(s
)
)

assert self._an_open_partition.close.call_count == 0
assert self._stream.cursor.close_partition.call_count == 0

def test_is_done_is_false_if_there_are_any_instances_to_read_from(self):
stream_instances_to_read_from = [self._stream]
Expand Down
Loading