fix(cdk): determine state from manager if not received a state in per partition router #544

Merged · 6 commits · May 15, 2025
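Editor's note on the gist of the fix: for declarative streams that combine `incremental_sync` with an `AsyncRetriever`, `_build_incremental_cursor` used to hard-code `stream_state={}` when building the per-partition concurrent cursor, so any state handed to the connector was ignored and every sync restarted from the configured start date. After this change the factory resolves the state from its `ConnectorStateManager`. Below is a minimal runnable sketch of that lookup; the stream name and state values are illustrative, mirroring the test added in this PR.

```python
from airbyte_cdk.models import (
    AirbyteStateBlob,
    AirbyteStateMessage,
    AirbyteStateType,
    AirbyteStreamState,
    StreamDescriptor,
)
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager

# Saved per-partition state for the "lists" stream (illustrative values).
saved_state = {
    "states": [
        {
            "partition": {"account_id": 999999999},
            "cursor": {"TimePeriod": "2025-05-06T12:00:00+0000"},
        }
    ],
    "state": {"TimePeriod": "2025-05-12T12:00:00+0000"},
}

state_manager = ConnectorStateManager(
    state=[
        AirbyteStateMessage(
            type=AirbyteStateType.STREAM,
            stream=AirbyteStreamState(
                stream_descriptor=StreamDescriptor(name="lists"),
                stream_state=AirbyteStateBlob(saved_state),
            ),
        )
    ]
)

# This is the lookup the factory now performs instead of passing stream_state={}:
stream_state = state_manager.get_stream_state("lists", None)
assert stream_state == saved_state
```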
airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py
@@ -1484,6 +1484,7 @@ def create_concurrent_cursor_from_perpartition_cursor(
                 stream_state_migrations=stream_state_migrations,
             )
         )
 
+        stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)
         # Per-partition state doesn't make sense for GroupingPartitionRouter, so force the global state
         use_global_cursor = isinstance(
@@ -1993,14 +1994,19 @@ def _build_incremental_cursor(
     ) -> Optional[StreamSlicer]:
         if model.incremental_sync and stream_slicer:
             if model.retriever.type == "AsyncRetriever":
+                stream_name = model.name or ""
+                stream_namespace = None
+                stream_state = self._connector_state_manager.get_stream_state(
+                    stream_name, stream_namespace
+                )
                 return self.create_concurrent_cursor_from_perpartition_cursor(  # type: ignore # This is a known issue that we are creating and returning a ConcurrentCursor which does not technically implement the (low-code) StreamSlicer. However, (low-code) StreamSlicer and ConcurrentCursor both implement StreamSlicer.stream_slices() which is the primary method needed for checkpointing
                     state_manager=self._connector_state_manager,
                     model_type=DatetimeBasedCursorModel,
                     component_definition=model.incremental_sync.__dict__,
-                    stream_name=model.name or "",
-                    stream_namespace=None,
+                    stream_name=stream_name,
+                    stream_namespace=stream_namespace,
                     config=config or {},
-                    stream_state={},
+                    stream_state=stream_state,
                     partition_router=stream_slicer,
                 )

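For reference, here is the shape of the per-partition state blob that now reaches `create_concurrent_cursor_from_perpartition_cursor`. The values are illustrative, taken from the test added in this PR; the inline comments are the editor's reading of the fields.

```python
# Per-partition state consumed by ConcurrentPerPartitionCursor (illustrative values).
stream_state = {
    "use_global_cursor": False,  # keep an individual cursor per partition
    "states": [  # one cursor entry per partition seen so far
        {
            "partition": {
                "account_id": 999999999,
                "parent_slice": {"parent_slice": {}, "user_id": "102023653"},
            },
            "cursor": {"TimePeriod": "2025-05-06T12:00:00+0000"},
        }
    ],
    "state": {"TimePeriod": "2025-05-12T12:00:00+0000"},  # global fallback cursor
    "lookback_window": 46,  # re-read window applied when resuming
}
```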
resources/stream_with_incremental_and_aync_retriever_with_partition_router.yaml (new file)
@@ -0,0 +1,140 @@
decoder:
  type: JsonDecoder
extractor:
  type: DpathExtractor
selector:
  type: RecordSelector
  record_filter:
    type: RecordFilter
    condition: "{{ record['id'] > stream_state['id'] }}"
requester:
  type: HttpRequester
  name: "{{ parameters['name'] }}"
  url_base: "https://api.sendgrid.com/v3/"
  http_method: "GET"
  authenticator:
    type: SessionTokenAuthenticator
    decoder:
      type: JsonDecoder
    expiration_duration: P10D
    login_requester:
      path: /session
      type: HttpRequester
      url_base: 'https://api.sendgrid.com'
      http_method: POST
      request_body_json:
        password: '{{ config.apikey }}'
        username: '{{ parameters.name }}'
    session_token_path:
      - id
    request_authentication:
      type: ApiKey
      inject_into:
        type: RequestOption
        field_name: X-Metabase-Session
        inject_into: header
  request_parameters:
    unit: "day"
list_stream:
  type: DeclarativeStream
  name: lists
  schema_loader:
    type: JsonFileSchemaLoader
    file_path: "./source_sendgrid/schemas/{{ parameters.name }}.json"
  incremental_sync:
    type: DatetimeBasedCursor
    $parameters:
      datetime_format: "%Y-%m-%dT%H:%M:%S%z"
    start_datetime:
      type: MinMaxDatetime
      datetime: "{{ config['reports_start_date'] }}"
      datetime_format: "%Y-%m-%d"
    end_datetime:
      type: MinMaxDatetime
      datetime: "{{ format_datetime(now_utc(), '%Y-%m-%d') }}"
      datetime_format: "%Y-%m-%d"
    cursor_field: TimePeriod
    cursor_datetime_formats:
      - "%Y-%m-%dT%H:%M:%S%z"
  retriever:
    type: AsyncRetriever
    name: "{{ parameters['name'] }}"
    decoder:
      $ref: "#/decoder"
    partition_router:
      type: ListPartitionRouter
      values: "{{config['repos']}}"
      cursor_field: a_key
      request_option:
        inject_into: header
        field_name: a_key
    status_mapping:
      failed:
        - Error
      running:
        - Pending
      completed:
        - Success
      timeout: [ ]
    status_extractor:
      type: DpathExtractor
      field_path:
        - ReportRequestStatus
        - Status
    download_target_extractor:
      type: DpathExtractor
      field_path:
        - ReportRequestStatus
        - ReportDownloadUrl
    creation_requester:
      type: HttpRequester
      url_base: https://reporting.api.bingads.microsoft.com/
      path: Reporting/v13/GenerateReport/Submit
      http_method: POST
      request_headers:
        Content-Type: application/json
        DeveloperToken: "{{ config['developer_token'] }}"
        CustomerId: "'{{ stream_partition['customer_id'] }}'"
        CustomerAccountId: "'{{ stream_partition['account_id'] }}'"
      request_body_json:
        ReportRequest:
          ExcludeColumnHeaders: false
    polling_requester:
      type: HttpRequester
      url_base: https://fakerporting.api.bingads.microsoft.com/
      path: Reporting/v13/GenerateReport/Poll
      http_method: POST
      request_headers:
        Content-Type: application/json
        DeveloperToken: "{{ config['developer_token'] }}"
        CustomerId: "'{{ stream_partition['customer_id'] }}'"
        CustomerAccountId: "'{{ stream_partition['account_id'] }}'"
      request_body_json:
        ReportRequestId: "'{{ creation_response['ReportRequestId'] }}'"
    download_requester:
      type: HttpRequester
      url_base: "{{download_target}}"
      http_method: GET
    paginator:
      type: DefaultPaginator
      page_size_option:
        inject_into: request_parameter
        field_name: page_size
      page_token_option:
        inject_into: path
        type: RequestPath
      pagination_strategy:
        type: "CursorPagination"
        cursor_value: "{{ response._metadata.next }}"
        page_size: 10
    requester:
      $ref: "#/requester"
      path: "{{ next_page_token['next_page_url'] }}"
    record_selector:
      $ref: "#/selector"
  $parameters:
    name: "lists"
    primary_key: "id"
    extractor:
      $ref: "#/extractor"
      field_path: ["{{ parameters['name'] }}"]
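The manifest above is exercised by the new test below. As a condensed sketch, building the stream from such a manifest with state injected looks roughly like this, assuming `content` holds the YAML above and `state_manager` is a `ConnectorStateManager` seeded with the stream's state, as in the test:

```python
from airbyte_cdk.sources.declarative.models.declarative_component_schema import (
    DeclarativeStreamModel,
)
from airbyte_cdk.sources.declarative.parsers.manifest_component_transformer import (
    ManifestComponentTransformer,
)
from airbyte_cdk.sources.declarative.parsers.manifest_reference_resolver import (
    ManifestReferenceResolver,
)
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
    ModelToComponentFactory,
)
from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource

parsed = YamlDeclarativeSource._parse(content)
resolved = ManifestReferenceResolver().preprocess_manifest(parsed)
stream_manifest = ManifestComponentTransformer().propagate_types_and_parameters(
    "", resolved["list_stream"], {}
)

# The state manager passed here is what get_stream_state() is resolved against.
stream = ModelToComponentFactory(connector_state_manager=state_manager).create_component(
    model_type=DeclarativeStreamModel,
    component_definition=stream_manifest,
    config={"apikey": "verysecrettoken", "repos": ["airbyte"], "reports_start_date": "2025-01-01"},
)
```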
unit_tests/sources/declarative/parsers/test_model_to_component_factory.py
@@ -1,14 +1,17 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
+from copy import deepcopy
 
 # mypy: ignore-errors
 from datetime import datetime, timedelta, timezone
-from typing import Any, Iterable, Mapping
+from pathlib import Path
+from typing import Any, Iterable, Mapping, Optional, Union
 
 import freezegun
 import pytest
 import requests
+from freezegun.api import FakeDatetime
 from pydantic.v1 import ValidationError
 
 from airbyte_cdk import AirbyteTracedException
@@ -42,6 +45,7 @@
     ClientSideIncrementalRecordFilterDecorator,
 )
 from airbyte_cdk.sources.declarative.incremental import (
+    ConcurrentPerPartitionCursor,
     CursorFactory,
     DatetimeBasedCursor,
     PerPartitionCursor,
@@ -166,7 +170,7 @@
     MonthClampingStrategy,
     WeekClampingStrategy,
 )
-from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor
+from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField
 from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import (
     CustomFormatConcurrentStreamStateConverter,
 )
@@ -190,6 +194,21 @@
 input_config = {"apikey": "verysecrettoken", "repos": ["airbyte", "airbyte-cloud"]}
 
 
+def get_factory_with_parameters(
+    connector_state_manager: Optional[ConnectorStateManager] = None,
+) -> ModelToComponentFactory:
+    return ModelToComponentFactory(
+        connector_state_manager=connector_state_manager,
+    )
+
+
+def read_yaml_file(resource_path: Union[str, Path]) -> str:
+    yaml_path = Path(__file__).parent / resource_path
+    with open(yaml_path, "r") as file:
+        content = file.read()
+    return content
+
+
 def test_create_check_stream():
     manifest = {"check": {"type": "CheckStream", "stream_names": ["list_stream"]}}
 
@@ -925,6 +944,97 @@ def test_stream_with_incremental_and_retriever_with_partition_router():
     assert list_stream_slicer._cursor_field.string == "a_key"
 
 
+@freezegun.freeze_time("2025-05-14")
+def test_stream_with_incremental_and_async_retriever_with_partition_router():
+    content = read_yaml_file(
+        "resources/stream_with_incremental_and_aync_retriever_with_partition_router.yaml"
+    )
+    parsed_manifest = YamlDeclarativeSource._parse(content)
+    resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
+    stream_manifest = transformer.propagate_types_and_parameters(
+        "", resolved_manifest["list_stream"], {}
+    )
+    cursor_time_period_value = "2025-05-06T12:00:00+0000"
+    cursor_field_key = "TimePeriod"
+    parent_user_id = "102023653"
+    per_partition_key = {
+        "account_id": 999999999,
+        "parent_slice": {"parent_slice": {}, "user_id": parent_user_id},
+    }
+    stream_state = {
+        "use_global_cursor": False,
+        "states": [
+            {"partition": per_partition_key, "cursor": {cursor_field_key: cursor_time_period_value}}
+        ],
+        "state": {cursor_field_key: "2025-05-12T12:00:00+0000"},
+        "lookback_window": 46,
+    }
+    connector_state_manager = ConnectorStateManager(
+        state=[
+            AirbyteStateMessage(
+                type=AirbyteStateType.STREAM,
+                stream=AirbyteStreamState(
+                    stream_descriptor=StreamDescriptor(name="lists"),
+                    stream_state=AirbyteStateBlob(stream_state),
+                ),
+            )
+        ]
+    )
+
+    factory_with_parameters = get_factory_with_parameters(
+        connector_state_manager=connector_state_manager
+    )
+    connector_config = deepcopy(input_config)
+    connector_config["reports_start_date"] = "2025-01-01"
+    stream = factory_with_parameters.create_component(
+        model_type=DeclarativeStreamModel,
+        component_definition=stream_manifest,
+        config=connector_config,
+    )
+
+    assert isinstance(stream, DeclarativeStream)
+    assert isinstance(stream.retriever, AsyncRetriever)
+    stream_slicer = stream.retriever.stream_slicer.stream_slicer
+    assert isinstance(stream_slicer, ConcurrentPerPartitionCursor)
+    assert stream_slicer.state == stream_state
+    import json
+
+    cursor_perpartition = stream_slicer._cursor_per_partition
+    expected_cursor_perpartition_key = json.dumps(per_partition_key, sort_keys=True).replace(
+        " ", ""
+    )
+    assert (
+        cursor_perpartition[expected_cursor_perpartition_key].cursor_field.cursor_field_key
+        == cursor_field_key
+    )
+    assert cursor_perpartition[expected_cursor_perpartition_key].start == datetime(
+        2025, 5, 6, 12, 0, tzinfo=timezone.utc
+    )
+    assert (
+        cursor_perpartition[expected_cursor_perpartition_key].state[cursor_field_key]
+        == cursor_time_period_value
+    )
+
+    concurrent_cursor = cursor_perpartition[expected_cursor_perpartition_key]
+    assert concurrent_cursor._concurrent_state == {
+        "legacy": {cursor_field_key: cursor_time_period_value},
+        "slices": [
+            {
+                "end": FakeDatetime(2025, 5, 6, 12, 0, tzinfo=timezone.utc),
+                "most_recent_cursor_value": FakeDatetime(2025, 5, 6, 12, 0, tzinfo=timezone.utc),
+                "start": FakeDatetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc),
+            }
+        ],
+        "state_type": "date-range",
+    }
+
+    stream_slices = list(concurrent_cursor.stream_slices())
+    expected_stream_slices = [
+        {"start_time": cursor_time_period_value, "end_time": "2025-05-14T00:00:00+0000"}
+    ]
+    assert stream_slices == expected_stream_slices
+
+
 def test_resumable_full_refresh_stream():
     content = """
 decoder: