feat(airbyte-cdk): add client side incremental sync (#38099)
Signed-off-by: Artem Inzhyyants <artem.inzhyyants@gmail.com>
artem1205 authored Jun 3, 2024
1 parent 8412602 commit b9a421b
Showing 8 changed files with 415 additions and 37 deletions.
@@ -801,6 +801,10 @@ definitions:
        title: Whether the target API is formatted as a data feed
        description: A data feed API is an API that does not allow filtering and paginates the content from the most recent to the least recent. Given this, the CDK needs to know when to stop paginating and this field will generate a stop condition for pagination.
        type: boolean
+      is_client_side_incremental:
+        title: Whether the target API does not support filtering and returns all data (records are filtered on the client side instead of the API side)
+        description: If the target API endpoint does not accept cursor values to filter records, and instead returns all records, a connector with this cursor will filter records locally and emit only records that are new since the last sync, making the stream incremental. All records are still read from the API, but only new records are emitted to the destination.
+        type: boolean
      lookback_window:
        title: Lookback Window
        description: Time interval before the start_datetime to read data for, e.g. P1M for looking back one month.
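For illustration, a hypothetical manifest excerpt that enables the new flag might look like the following; the cursor field, datetime format, and config key are invented for this example, not taken from the commit:

incremental_sync:
  type: DatetimeBasedCursor
  cursor_field: updated_at
  datetime_format: "%Y-%m-%dT%H:%M:%SZ"
  start_datetime: "{{ config['start_date'] }}"
  is_client_side_incremental: true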
@@ -1,10 +1,11 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import datetime
from dataclasses import InitVar, dataclass
from typing import Any, Iterable, Mapping, Optional

from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor, PerPartitionCursor
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState

@@ -36,3 +37,80 @@ def filter_records(
        for record in records:
            if self._filter_interpolator.eval(self.config, record=record, **kwargs):
                yield record


class ClientSideIncrementalRecordFilterDecorator(RecordFilter):
    """
    Applies a filter to a list of records, excluding any that are older than the stream_state/start_date.

    :param DatetimeBasedCursor date_time_based_cursor: Cursor used to extract datetime values
    :param PerPartitionCursor per_partition_cursor: Optional cursor used to look up the cursor value in a nested (per-partition) stream_state
    """

    def __init__(
        self, date_time_based_cursor: DatetimeBasedCursor, per_partition_cursor: Optional[PerPartitionCursor] = None, **kwargs: Any
    ):
        super().__init__(**kwargs)
        self._date_time_based_cursor = date_time_based_cursor
        self._per_partition_cursor = per_partition_cursor

    @property
    def _cursor_field(self) -> str:
        return self._date_time_based_cursor.cursor_field.eval(self._date_time_based_cursor.config)  # type: ignore # eval returns a string in this context

    @property
    def _start_date_from_config(self) -> datetime.datetime:
        return self._date_time_based_cursor._start_datetime.get_datetime(self._date_time_based_cursor.config)

    @property
    def _end_datetime(self) -> datetime.datetime:
        return (
            self._date_time_based_cursor._end_datetime.get_datetime(self._date_time_based_cursor.config)
            if self._date_time_based_cursor._end_datetime
            else datetime.datetime.max
        )

    def filter_records(
        self,
        records: Iterable[Mapping[str, Any]],
        stream_state: StreamState,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Mapping[str, Any]]:
        state_value = self._get_state_value(stream_state, stream_slice or StreamSlice(partition={}, cursor_slice={}))
        filter_date: datetime.datetime = self._get_filter_date(state_value)
        records = (
            record
            for record in records
            if self._end_datetime > self._date_time_based_cursor.parse_date(record[self._cursor_field]) > filter_date
        )
        if self.condition:
            records = super().filter_records(
                records=records, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
            )
        yield from records

    def _get_state_value(self, stream_state: StreamState, stream_slice: StreamSlice) -> Optional[str]:
        """
        Return the cursor_value, or None if it was not found.

        The cursor_value may be missing when:
        1. This is an initial sync, so no stream_state exists yet.
        2. In a parent-child stream whose initial sync has already run, stream_state is present,
           but the second read receives an extra record from the parent for which no stream_state entry exists yet.

        :param StreamState stream_state: State
        :param StreamSlice stream_slice: Current stream slice
        :return Optional[str]: cursor_value if found, otherwise None.
        """
        if self._per_partition_cursor:
            # self._per_partition_cursor is the same object that DeclarativeStream uses to save/update stream_state
            partition_state = self._per_partition_cursor.select_state(stream_slice=stream_slice)
            return partition_state.get(self._cursor_field) if partition_state else None
        return stream_state.get(self._cursor_field)

    def _get_filter_date(self, state_value: Optional[str]) -> datetime.datetime:
        start_date_parsed = self._start_date_from_config
        if state_value:
            return max(start_date_parsed, self._date_time_based_cursor.parse_date(state_value))
        else:
            return start_date_parsed
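To illustrate what the decorator above does, here is a minimal standalone sketch of the same filtering rule: the effective lower bound is the later of the configured start date and the state value, and only records strictly inside the (filter_date, end_datetime) window are emitted. The record shape, dates, and ISO-format parsing are stand-ins for what the real cursor provides.

import datetime

# Stand-ins for values the decorator derives from the cursor and stream state.
start_date = datetime.datetime(2024, 1, 1)
state_value = "2024-03-01T00:00:00"
filter_date = max(start_date, datetime.datetime.fromisoformat(state_value))  # mirrors _get_filter_date
end_datetime = datetime.datetime.max

records = [
    {"id": 1, "updated_at": "2024-02-15T00:00:00"},  # older than state -> filtered out
    {"id": 2, "updated_at": "2024-04-01T00:00:00"},  # newer than state -> emitted
]

emitted = [
    record
    for record in records
    if end_datetime > datetime.datetime.fromisoformat(record["updated_at"]) > filter_date
]
assert [record["id"] for record in emitted] == [2]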
@@ -87,7 +87,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
            else datetime.timedelta.max
        )
        self._cursor_granularity = self._parse_timedelta(self.cursor_granularity)
-        self._cursor_field = InterpolatedString.create(self.cursor_field, parameters=parameters)
+        self.cursor_field = InterpolatedString.create(self.cursor_field, parameters=parameters)
        self._lookback_window = InterpolatedString.create(self.lookback_window, parameters=parameters) if self.lookback_window else None
        self._partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters)
        self._partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters)
@@ -103,7 +103,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
            self.cursor_datetime_formats = [self.datetime_format]

    def get_stream_state(self) -> StreamState:
-        return {self._cursor_field.eval(self.config): self._cursor} if self._cursor else {}
+        return {self.cursor_field.eval(self.config): self._cursor} if self._cursor else {}  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__

    def set_initial_state(self, stream_state: StreamState) -> None:
        """
@@ -112,7 +112,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:

        :param stream_state: The state of the stream as returned by get_stream_state
        """
-        self._cursor = stream_state.get(self._cursor_field.eval(self.config)) if stream_state else None
+        self._cursor = stream_state.get(self.cursor_field.eval(self.config)) if stream_state else None  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__

    def observe(self, stream_slice: StreamSlice, record: Record) -> None:
        """
@@ -122,7 +122,7 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
        :param record: the most recently-read record, which the cursor can use to update the stream state. Outwardly-visible changes to the
          stream state may need to be deferred depending on whether the source reliably orders records by the cursor field.
        """
-        record_cursor_value = record.get(self._cursor_field.eval(self.config))
+        record_cursor_value = record.get(self.cursor_field.eval(self.config))  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__
        # if the current record has no cursor value, we cannot meaningfully update the state based on it, so there is nothing more to do
        if not record_cursor_value:
            return
@@ -186,8 +186,8 @@ def _select_best_end_datetime(self) -> datetime.datetime:
        return min(self._end_datetime.get_datetime(self.config), now)

    def _calculate_cursor_datetime_from_state(self, stream_state: Mapping[str, Any]) -> datetime.datetime:
-        if self._cursor_field.eval(self.config, stream_state=stream_state) in stream_state:
-            return self.parse_date(stream_state[self._cursor_field.eval(self.config)])
+        if self.cursor_field.eval(self.config, stream_state=stream_state) in stream_state:  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__
+            return self.parse_date(stream_state[self.cursor_field.eval(self.config)])  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__
        return datetime.datetime.min.replace(tzinfo=datetime.timezone.utc)

    def _format_datetime(self, dt: datetime.datetime) -> str:
@@ -300,7 +300,7 @@ def _get_request_options(self, option_type: RequestOptionType, stream_slice: Opt
        return options

    def should_be_synced(self, record: Record) -> bool:
-        cursor_field = self._cursor_field.eval(self.config)
+        cursor_field = self.cursor_field.eval(self.config)  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__
        record_cursor_value = record.get(cursor_field)
        if not record_cursor_value:
            self._send_log(
@@ -315,7 +315,7 @@ def should_be_synced(self, record: Record) -> bool:
    def _is_within_daterange_boundaries(
        self, record: Record, start_datetime_boundary: Union[datetime.datetime, str], end_datetime_boundary: Union[datetime.datetime, str]
    ) -> bool:
-        cursor_field = self._cursor_field.eval(self.config)
+        cursor_field = self.cursor_field.eval(self.config)  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__
        record_cursor_value = record.get(cursor_field)
        if not record_cursor_value:
            self._send_log(
@@ -339,7 +339,7 @@ def _send_log(self, level: Level, message: str) -> None:
        )

    def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
-        cursor_field = self._cursor_field.eval(self.config)
+        cursor_field = self.cursor_field.eval(self.config)  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__
        first_cursor_value = first.get(cursor_field)
        second_cursor_value = second.get(cursor_field)
        if first_cursor_value and second_cursor_value:
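The repeated `# type: ignore` comments are needed because `cursor_field` is declared as a string on the dataclass but is replaced with an InterpolatedString in __post_init__. A small sketch of that round trip, assuming the interpolated_string module path that parallels the interpolated_boolean import shown earlier (a plain, non-templated string passes through unchanged):

from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString

# A non-templated cursor field name evaluates to itself.
cursor_field = InterpolatedString.create("updated_at", parameters={})
assert cursor_field.eval(config={}) == "updated_at"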
@@ -937,6 +937,11 @@ class DatetimeBasedCursor(BaseModel):
        description='A data feed API is an API that does not allow filtering and paginates the content from the most recent to the least recent. Given this, the CDK needs to know when to stop paginating and this field will generate a stop condition for pagination.',
        title='Whether the target API is formatted as a data feed',
    )
+    is_client_side_incremental: Optional[bool] = Field(
+        None,
+        description='If the target API endpoint does not accept cursor values to filter records, and instead returns all records, a connector with this cursor will filter records locally and emit only records that are new since the last sync, making the stream incremental. All records are still read from the API, but only new records are emitted to the destination.',
+        title='Whether the target API does not support filtering and returns all data (records are filtered on the client side instead of the API side)',
+    )
    lookback_window: Optional[str] = Field(
        None,
        description='Time interval before the start_datetime to read data for, e.g. P1M for looking back one month.',
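The generated field behaves like any optional Pydantic boolean: when the manifest omits it, the value is None, so downstream code checks truthiness rather than presence. A minimal sketch with a stand-in model (not the generated DatetimeBasedCursor class):

from typing import Optional

from pydantic import BaseModel, Field

class CursorFlagsStub(BaseModel):
    # Stand-in for the new field on the generated model.
    is_client_side_incremental: Optional[bool] = Field(None)

assert CursorFlagsStub().is_client_side_incremental is None
assert CursorFlagsStub(is_client_side_incremental=True).is_client_side_incremental is True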
@@ -7,7 +7,7 @@
import importlib
import inspect
import re
-from typing import Any, Callable, List, Mapping, Optional, Type, Union, get_args, get_origin, get_type_hints
+from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union, get_args, get_origin, get_type_hints

from airbyte_cdk.models import Level
from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator, JwtAuthenticator
@@ -27,6 +27,7 @@
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.decoders import JsonDecoder
from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordFilter, RecordSelector
+from airbyte_cdk.sources.declarative.extractors.record_filter import ClientSideIncrementalRecordFilterDecorator
from airbyte_cdk.sources.declarative.extractors.record_selector import SCHEMA_TRANSFORMER_TYPE_MAPPING
from airbyte_cdk.sources.declarative.incremental import (
    CursorFactory,
@@ -558,6 +559,8 @@ def create_datetime_based_cursor(self, model: DatetimeBasedCursorModel, config:
        end_datetime: Union[str, MinMaxDatetime, None] = None
        if model.is_data_feed and model.end_datetime:
            raise ValueError("Data feed does not support end_datetime")
+        if model.is_data_feed and model.is_client_side_incremental:
+            raise ValueError("`Client side incremental` cannot be applied with `data feed`. Choose only one of them.")
        if model.end_datetime:
            end_datetime = (
                model.end_datetime if isinstance(model.end_datetime, str) else self.create_min_max_datetime(model.end_datetime, config)
@@ -611,6 +614,18 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi
        stop_condition_on_cursor = (
            model.incremental_sync and hasattr(model.incremental_sync, "is_data_feed") and model.incremental_sync.is_data_feed
        )
+        client_side_incremental_sync = None
+        if (
+            model.incremental_sync
+            and hasattr(model.incremental_sync, "is_client_side_incremental")
+            and model.incremental_sync.is_client_side_incremental
+        ):
+            if combined_slicers and not isinstance(combined_slicers, (DatetimeBasedCursor, PerPartitionCursor)):
+                raise ValueError("Unsupported slicer is used. PerPartitionCursor should be used here instead")
+            client_side_incremental_sync = {
+                "date_time_based_cursor": self._create_component_from_model(model=model.incremental_sync, config=config),
+                "per_partition_cursor": combined_slicers if isinstance(combined_slicers, PerPartitionCursor) else None,
+            }
        transformations = []
        if model.transformations:
            for transformation_model in model.transformations:
@@ -622,6 +637,7 @@ def create_declarative_stream(self, model: DeclarativeStreamModel, config: Confi
            primary_key=primary_key,
            stream_slicer=combined_slicers,
            stop_condition_on_cursor=stop_condition_on_cursor,
+            client_side_incremental_sync=client_side_incremental_sync,
            transformations=transformations,
        )
        cursor_field = model.incremental_sync.cursor_field if model.incremental_sync else None
@@ -982,11 +998,19 @@ def create_record_selector(
        config: Config,
        *,
        transformations: List[RecordTransformation],
+        client_side_incremental_sync: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> RecordSelector:
        assert model.schema_normalization is not None  # for mypy
        extractor = self._create_component_from_model(model=model.extractor, config=config)
        record_filter = self._create_component_from_model(model.record_filter, config=config) if model.record_filter else None
+        if client_side_incremental_sync:
+            record_filter = ClientSideIncrementalRecordFilterDecorator(
+                config=config,
+                parameters=model.parameters,
+                condition=model.record_filter.condition if model.record_filter else None,
+                **client_side_incremental_sync,
+            )
        schema_normalization = TypeTransformer(SCHEMA_TRANSFORMER_TYPE_MAPPING[model.schema_normalization])

        return RecordSelector(
@@ -1038,10 +1062,16 @@ def create_simple_retriever(
        primary_key: Optional[Union[str, List[str], List[List[str]]]],
        stream_slicer: Optional[StreamSlicer],
        stop_condition_on_cursor: bool = False,
+        client_side_incremental_sync: Optional[Dict[str, Any]] = None,
        transformations: List[RecordTransformation],
    ) -> SimpleRetriever:
        requester = self._create_component_from_model(model=model.requester, config=config, name=name)
-        record_selector = self._create_component_from_model(model=model.record_selector, config=config, transformations=transformations)
+        record_selector = self._create_component_from_model(
+            model=model.record_selector,
+            config=config,
+            transformations=transformations,
+            client_side_incremental_sync=client_side_incremental_sync,
+        )
        url_base = model.requester.url_base if hasattr(model.requester, "url_base") else requester.get_url_base()
        stream_slicer = stream_slicer or SinglePartitionRouter(parameters={})
        cursor = stream_slicer if isinstance(stream_slicer, DeclarativeCursor) else None
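Since a data-feed cursor and client-side incremental filtering are mutually exclusive, create_datetime_based_cursor rejects manifests that set both flags. A simplified standalone sketch of that guard, using a stub instead of the real Pydantic model:

from dataclasses import dataclass

@dataclass
class CursorModelStub:
    is_data_feed: bool = False
    is_client_side_incremental: bool = False

def validate_cursor_model(model: CursorModelStub) -> None:
    # Mirrors the mutual-exclusion check in create_datetime_based_cursor.
    if model.is_data_feed and model.is_client_side_incremental:
        raise ValueError("`Client side incremental` cannot be applied with `data feed`. Choose only one of them.")

validate_cursor_model(CursorModelStub(is_client_side_incremental=True))  # passes; setting both flags would raise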
1 change: 1 addition & 0 deletions airbyte-cdk/python/unit_tests/__init__.py
@@ -2,5 +2,6 @@

# Import the thing that needs to be imported to stop the tests from falling over
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource

# "Use" the thing so that the linter doesn't complain
placeholder = ManifestDeclarativeSource