Skip to content

Commit

Permalink
[per-stream cdk] Allow for reading in per-stream state and passing it…
Browse files Browse the repository at this point in the history
… to connectors (airbytehq#16505)

* update abstract source and connector state manager to update shared and stream state

* clean up some extra comments and extra lines

* a few changes based on PR feedback

* remove separate legacy map from state manager and simplify mapping to descriptor -> state blob

* rename a few parameters and add testing for state update via stream.state override

* replace shared_state processing with an explicit error and fix a few comments and pr feedback

* add some polish and additional test cases

* pr feedback and restructuring parts of the connector state manager initialization

* fix unfinished comment

* Update airbyte-cdk/python/unit_tests/sources/test_abstract_source.py

Co-authored-by: Augustin <augustin.lafanechere@gmail.com>

* use pytest params to annotate tests better

* change to fix changed class name

* Update airbyte-cdk/python/airbyte_cdk/sources/connector_state_manager.py

Co-authored-by: Sherif A. Nada <snadalive@gmail.com>

* a few bits of pr feedback

* pr feedback and cleaning up some comments and variable renames

Co-authored-by: Augustin <augustin.lafanechere@gmail.com>
Co-authored-by: Sherif A. Nada <snadalive@gmail.com>
  • Loading branch information
3 people authored Sep 15, 2022
1 parent a61ac9f commit db56c75
Show file tree
Hide file tree
Showing 4 changed files with 727 additions and 104 deletions.
38 changes: 20 additions & 18 deletions airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,12 @@ def read(
state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]] = None,
) -> Iterator[AirbyteMessage]:
"""Implements the Read operation from the Airbyte Specification. See https://docs.airbyte.io/architecture/airbyte-protocol."""
state_manager = ConnectorStateManager(state=state)
connector_state = state_manager.get_legacy_state()

logger.info(f"Starting syncing {self.name}")
config, internal_config = split_config(config)
# TODO assert all streams exist in the connector
# get the streams once in case the connector needs to make any queries to generate them
stream_instances = {s.name: s for s in self.streams(config)}
state_manager = ConnectorStateManager(stream_instance_map=stream_instances, state=state)
self._stream_to_instance_map = stream_instances
with create_timer(self.name) as timer:
for configured_stream in catalog.streams:
Expand All @@ -116,7 +114,7 @@ def read(
logger=logger,
stream_instance=stream_instance,
configured_stream=configured_stream,
connector_state=connector_state,
state_manager=state_manager,
internal_config=internal_config,
)
except AirbyteTracedException as e:
Expand All @@ -135,15 +133,15 @@ def read(
logger.info(f"Finished syncing {self.name}")

@property
def per_stream_state_enabled(self):
def per_stream_state_enabled(self) -> bool:
    """
    Flag indicating whether this source uses the per-stream state format.
    :return: Always False in this implementation
    """
    return False  # While CDK per-stream is in active development we should keep this off

def _read_stream(
self,
logger: logging.Logger,
stream_instance: Stream,
configured_stream: ConfiguredAirbyteStream,
connector_state: MutableMapping[str, Any],
state_manager: ConnectorStateManager,
internal_config: InternalConfig,
) -> Iterator[AirbyteMessage]:
self._apply_log_level_to_stream_logger(logger, stream_instance)
Expand Down Expand Up @@ -172,7 +170,7 @@ def _read_stream(
logger,
stream_instance,
configured_stream,
connector_state,
state_manager,
internal_config,
)
else:
Expand Down Expand Up @@ -206,20 +204,21 @@ def _read_incremental(
logger: logging.Logger,
stream_instance: Stream,
configured_stream: ConfiguredAirbyteStream,
connector_state: MutableMapping[str, Any],
state_manager: ConnectorStateManager,
internal_config: InternalConfig,
) -> Iterator[AirbyteMessage]:
"""Read stream using incremental algorithm
:param logger:
:param stream_instance:
:param configured_stream:
:param connector_state:
:param state_manager:
:param internal_config:
:return:
"""
stream_name = configured_stream.stream.name
stream_state = connector_state.get(stream_name, {})
stream_state = state_manager.get_stream_state(stream_name, stream_instance.namespace)

if stream_state and "state" in dir(stream_instance):
stream_instance.state = stream_state
logger.info(f"Setting state of {stream_name} stream to {stream_state}")
Expand All @@ -233,7 +232,7 @@ def _read_incremental(
total_records_counter = 0
if not slices:
# Safety net to ensure we always emit at least one state message even if there are no slices
checkpoint = self._checkpoint_state(stream_instance, stream_state, connector_state)
checkpoint = self._checkpoint_state(stream_instance, stream_state, state_manager)
yield checkpoint
for _slice in slices:
logger.debug("Processing stream slice", extra={"slice": _slice})
Expand All @@ -248,7 +247,7 @@ def _read_incremental(
stream_state = stream_instance.get_updated_state(stream_state, record_data)
checkpoint_interval = stream_instance.state_checkpoint_interval
if checkpoint_interval and record_counter % checkpoint_interval == 0:
yield self._checkpoint_state(stream_instance, stream_state, connector_state)
yield self._checkpoint_state(stream_instance, stream_state, state_manager)

total_records_counter += 1
# This functionality should ideally live outside of this method
Expand All @@ -258,7 +257,7 @@ def _read_incremental(
# Break from slice loop to save state and exit from _read_incremental function.
break

yield self._checkpoint_state(stream_instance, stream_state, connector_state)
yield self._checkpoint_state(stream_instance, stream_state, state_manager)
if self._limit_reached(internal_config, total_records_counter):
return

Expand All @@ -285,13 +284,16 @@ def _read_full_refresh(
if self._limit_reached(internal_config, total_records_counter):
return

@staticmethod
def _checkpoint_state(stream: Stream, stream_state, state_manager):
    """
    Records the latest state of a stream in the state manager and builds the STATE message to emit.
    :param stream: The stream instance being checkpointed
    :param stream_state: Fallback state, used when the stream instance does not implement the state property
    :param state_manager: The state manager that tracks the state of every stream in the sync
    :return: An AirbyteMessage of type STATE whose data is the state manager's legacy state mapping
    """
    # First attempt to retrieve the current state using the stream's state property. We receive an AttributeError if the state
    # property is not implemented by the stream instance and as a fallback, use the stream_state retrieved from the stream
    # instance's deprecated get_updated_state() method.
    # Only the property access is guarded so that an AttributeError raised inside the state manager is not mistaken for a
    # missing state property.
    try:
        current_state = stream.state
    except AttributeError:
        current_state = stream_state
    state_manager.update_state_for_stream(stream.name, stream.namespace, current_state)
    return AirbyteMessage(type=MessageType.STATE, state=AirbyteStateMessage(data=state_manager.get_legacy_state()))

@lru_cache(maxsize=None)
def _get_stream_transformer_and_schema(self, stream_name: str) -> Tuple[TypeTransformer, Mapping[str, Any]]:
Expand Down
166 changes: 140 additions & 26 deletions airbyte-cdk/python/airbyte_cdk/sources/connector_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,22 @@
#

import copy
from typing import Any, List, Mapping, MutableMapping, Union
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

from airbyte_cdk.models import AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType
from airbyte_cdk.models import AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType, StreamDescriptor
from airbyte_cdk.sources.streams import Stream
from pydantic import Extra


class HashableStreamDescriptor(StreamDescriptor):
    """
    Helper class that overrides the existing StreamDescriptor class that is auto generated from the Airbyte Protocol and
    freezes its fields so that it can be used as a hash key. This is only marked public because we use it outside for unit tests.
    """

    class Config:
        # Keep accepting fields that are not declared on the model
        extra = Extra.allow
        # Freezing the model is what allows instances to be used as dictionary keys
        frozen = True


class ConnectorStateManager:
Expand All @@ -14,41 +27,142 @@ class ConnectorStateManager:
interface. It also provides methods to extract and update state
"""

# In the immediate, we only persist legacy which will be used during abstract_source.read(). In the subsequent PRs we will
# initialize the ConnectorStateManager according to the new per-stream interface received from the platform
def __init__(self, state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]] = None):
if not state:
self.legacy = {}
elif self.is_migrated_legacy_state(state):
# The legacy state format received from the platform is parsed and stored as a single AirbyteStateMessage when reading
# the file. This is used for input backwards compatibility.
self.legacy = state[0].data
elif isinstance(state, MutableMapping):
# In the event that legacy state comes in as its original JSON object format, no changes to the input need to be made
self.legacy = state
def __init__(
    self, stream_instance_map: Mapping[str, Stream], state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]] = None
):
    """
    Initializes the state manager from the state input handed to the connector.
    :param stream_instance_map: A mapping of stream name to stream instance, forwarded to the state extraction helper
    :param state: The incoming state, either a list of state messages or a legacy state mapping. May be None
    :raises ValueError: If the extracted state contains a non-empty shared_state
    """
    shared_state, per_stream_states = self._extract_from_state_message(state, stream_instance_map)

    # We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are
    # designed to checkpoint state independently of one another. API sources should never be emitting a state message where
    # shared_state is populated. Rather than define how to handle shared_state without a clear use case, we're opting to throw an
    # error instead and if/when we find one, we will then implement processing of the shared_state value.
    if shared_state:
        raise ValueError(
            "Received a GLOBAL AirbyteStateMessage that contains a shared_state. This library only ever generates per-STREAM "
            "STATE messages, so it was not generated by this connector. This must be an orchestrator or platform error. GLOBAL "
            "state messages with shared_state will not be processed correctly."
        )
    self.per_stream_states = per_stream_states

def get_stream_state(self, stream_name: str, namespace: Optional[str]) -> Mapping[str, Any]:
    """
    Looks up the state stored for a single stream, identified by its descriptor (name + namespace).
    :param stream_name: Name of the stream being fetched
    :param namespace: Namespace of the stream being fetched
    :return: The stored state converted with .dict(), or an empty mapping when nothing truthy is stored for the descriptor
    """
    descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
    state_blob = self.per_stream_states.get(descriptor)
    return state_blob.dict() if state_blob else {}

def get_legacy_state(self) -> Mapping[str, Any]:
    """
    Assembles the legacy state format (stream name -> stream state) from the current per-stream state.
    Only the descriptor's name is used as the key, so descriptors that share a name end up in the same entry.
    :return: A new mapping of stream name to the stream's state as returned by .dict(), or {} when no state is stored
    """
    legacy_state = {}
    for descriptor, state_blob in self.per_stream_states.items():
        legacy_state[descriptor.name] = state_blob.dict() if state_blob else {}
    return legacy_state

def update_state_for_stream(self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]):
    """
    Replaces whatever state is stored for a stream with the provided value.
    :param stream_name: The name of the stream whose state is being updated
    :param namespace: The namespace of the stream if it exists
    :param value: The stream state mapping to store, parsed into an AirbyteStateBlob
    :return:
    """
    descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
    state_blob = AirbyteStateBlob.parse_obj(value)
    self.per_stream_states[descriptor] = state_blob

@classmethod
def _extract_from_state_message(
    cls, state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]], stream_instance_map: Mapping[str, Stream]
) -> Tuple[Optional[AirbyteStateBlob], MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]]]:
    """
    Takes an incoming list of state messages or the legacy state format and extracts state attributes according to type
    which can then be assigned to the new state manager being instantiated
    :param state: The incoming state input
    :param stream_instance_map: A mapping of stream name to stream instance, used to resolve namespaces for legacy state
    :return: A tuple of shared state and per stream state assembled from the incoming state list
    :raises ValueError: If the state is neither a list nor recognized as a legacy mapping
    """
    if state is None:
        return None, {}

    # Incoming pure legacy object format
    if cls._is_legacy_dict_state(state):
        return None, cls._create_descriptor_to_stream_state_mapping(state, stream_instance_map)

    # When processing incoming state in source.read_state(), legacy state gets deserialized into List[AirbyteStateMessage]
    # which can be translated into independent per-stream state values
    if cls._is_migrated_legacy_state(state):
        return None, cls._create_descriptor_to_stream_state_mapping(state[0].data, stream_instance_map)

    if cls._is_global_state(state):
        global_state = state[0].global_
        shared_state = copy.deepcopy(global_state.shared_state)
        streams = {
            HashableStreamDescriptor(
                name=per_stream_state.stream_descriptor.name, namespace=per_stream_state.stream_descriptor.namespace
            ): per_stream_state.stream_state
            for per_stream_state in global_state.stream_states
        }
        return shared_state, streams

    if cls._is_per_stream_state(state):
        # Skip messages whose stream payload is missing or None instead of failing on the attribute access below
        streams = {
            HashableStreamDescriptor(
                name=message.stream.stream_descriptor.name, namespace=message.stream.stream_descriptor.namespace
            ): message.stream.stream_state
            for message in state
            if message.type == AirbyteStateType.STREAM and getattr(message, "stream", None) is not None
        }
        return None, streams

    raise ValueError("Input state should come in the form of list of Airbyte state messages or a mapping of states")

def get_stream_state(self, namespace: str, stream_name: str) -> AirbyteStateBlob:
# todo implement in upcoming PRs
pass

def get_legacy_state(self) -> MutableMapping[str, Any]:
@staticmethod
def _create_descriptor_to_stream_state_mapping(
    state: MutableMapping[str, Any], stream_to_instance_map: Mapping[str, Stream]
) -> MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]]:
    """
    Takes incoming state received in the legacy format and transforms it into a mapping of HashableStreamDescriptor to
    AirbyteStateBlob
    :param state: A mapping object representing the complete state of all streams in the legacy format
    :param stream_to_instance_map: A mapping of stream name to stream instance used to retrieve a stream's namespace
    :return: The mapping of all of a sync's streams to the corresponding stream state
    """
    streams = {}
    for stream_name, state_value in state.items():
        # Streams present in the incoming state but unknown to the connector get no namespace
        namespace = stream_to_instance_map[stream_name].namespace if stream_name in stream_to_instance_map else None
        stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
        # A missing or empty state value is stored as an empty blob
        streams[stream_descriptor] = AirbyteStateBlob.parse_obj(state_value or {})
    return streams

def update_state_for_stream(self, namespace: str, stream_name: str, value: Mapping[str, Any]):
# todo implement in upcoming PRs
pass
@staticmethod
def _is_legacy_dict_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]):
return isinstance(state, dict)

@staticmethod
def is_migrated_legacy_state(state: List[AirbyteStateMessage]) -> bool:
def _is_migrated_legacy_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
return (
isinstance(state, List)
and len(state) == 1
and isinstance(state[0], AirbyteStateMessage)
and state[0].type == AirbyteStateType.LEGACY
)

@staticmethod
def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
return (
isinstance(state, List)
and len(state) == 1
and isinstance(state[0], AirbyteStateMessage)
and state[0].type == AirbyteStateType.GLOBAL
)

@staticmethod
def _is_per_stream_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
return isinstance(state, List)
Loading

0 comments on commit db56c75

Please sign in to comment.