diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py index 4babf99e4a0d..e54264fa86c7 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py @@ -6,12 +6,10 @@ from collections import OrderedDict from typing import Any, Callable, Iterable, Mapping, Optional, Union -from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer from airbyte_cdk.sources.types import Record, StreamSlice, StreamState -from airbyte_cdk.utils import AirbyteTracedException class CursorFactory: @@ -48,6 +46,7 @@ class PerPartitionCursor(DeclarativeCursor): _NO_CURSOR_STATE: Mapping[str, Any] = {} _KEY = 0 _VALUE = 1 + _state_to_migrate_from: Mapping[str, Any] = {} def __init__(self, cursor_factory: CursorFactory, partition_router: PartitionRouter): self._cursor_factory = cursor_factory @@ -65,7 +64,8 @@ def stream_slices(self) -> Iterable[StreamSlice]: cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition)) if not cursor: - cursor = self._create_cursor(self._NO_CURSOR_STATE) + partition_state = self._state_to_migrate_from if self._state_to_migrate_from else self._NO_CURSOR_STATE + cursor = self._create_cursor(partition_state) self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor for cursor_slice in cursor.stream_slices(): @@ -113,15 +113,13 @@ def set_initial_state(self, stream_state: StreamState) -> None: return if "states" not in stream_state: - raise AirbyteTracedException( - internal_message=f"Could not sync parse the following state: {stream_state}", - message="The state for is format invalid. Validate that the migration steps included a reset and that it was performed " - "properly. Otherwise, please contact Airbyte support.", - failure_type=FailureType.config_error, - ) + # We assume that `stream_state` is in a global format that can be applied to all partitions. + # Example: {"global_state_format_key": "global_state_format_value"} + self._state_to_migrate_from = stream_state - for state in stream_state["states"]: - self._cursor_per_partition[self._to_partition_key(state["partition"])] = self._create_cursor(state["cursor"]) + else: + for state in stream_state["states"]: + self._cursor_per_partition[self._to_partition_key(state["partition"])] = self._create_cursor(state["cursor"]) # Set parent state for partition routers based on parent streams self._partition_router.set_initial_state(stream_state) diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py b/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py index b2c8d5faf46d..823405cb5152 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py @@ -6,12 +6,10 @@ from unittest.mock import Mock import pytest -from airbyte_cdk.models import FailureType from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, PerPartitionKeySerializer, StreamSlice from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter from airbyte_cdk.sources.types import Record -from airbyte_cdk.utils import AirbyteTracedException PARTITION = { "partition_key string": "partition value", @@ -519,10 +517,37 @@ def test_get_stream_state_includes_parent_state(mocked_cursor_factory, mocked_pa assert stream_state == expected_state -def test_given_invalid_state_when_set_initial_state_then_raise_config_error(mocked_cursor_factory, mocked_partition_router) -> None: - cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) - - with pytest.raises(AirbyteTracedException) as exception: - cursor.set_initial_state({"invalid_state": 1}) +def test_per_partition_state_when_set_initial_global_state(mocked_cursor_factory, mocked_partition_router) -> None: + first_partition = {"first_partition_key": "first_partition_value"} + second_partition = {"second_partition_key": "second_partition_value"} + global_state = {"global_state_format_key": "global_state_format_value"} - assert exception.value.failure_type == FailureType.config_error + mocked_partition_router.stream_slices.return_value = [ + StreamSlice(partition=first_partition, cursor_slice={}), + StreamSlice(partition=second_partition, cursor_slice={}), + ] + mocked_cursor_factory.create.side_effect = [ + MockedCursorBuilder().with_stream_state(global_state).build(), + MockedCursorBuilder().with_stream_state(global_state).build(), + ] + cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) + global_state = {"global_state_format_key": "global_state_format_value"} + cursor.set_initial_state(global_state) + assert cursor._state_to_migrate_from == global_state + list(cursor.stream_slices()) + assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_count == 1 + assert cursor._cursor_per_partition['{"first_partition_key":"first_partition_value"}'].set_initial_state.call_args[0] == ( + {"global_state_format_key": "global_state_format_value"}, + ) + assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_count == 1 + assert cursor._cursor_per_partition['{"second_partition_key":"second_partition_value"}'].set_initial_state.call_args[0] == ( + {"global_state_format_key": "global_state_format_value"}, + ) + expected_state = [ + {"cursor": {"global_state_format_key": "global_state_format_value"}, "partition": {"first_partition_key": "first_partition_value"}}, + { + "cursor": {"global_state_format_key": "global_state_format_value"}, + "partition": {"second_partition_key": "second_partition_value"}, + }, + ] + assert cursor.get_stream_state()["states"] == expected_state