Skip to content

fix: (CDK) (ConnectorBuilder) - Add auxiliary requests to slice; support TestRead for AsyncRetriever (part 1/2) #355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions airbyte_cdk/connector_builder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,6 @@ class HttpRequest:
body: Optional[str] = None


@dataclass
class StreamReadPages:
records: List[object]
request: Optional[HttpRequest] = None
response: Optional[HttpResponse] = None


@dataclass
class StreamReadSlices:
pages: List[StreamReadPages]
slice_descriptor: Optional[Dict[str, Any]]
state: Optional[List[Dict[str, Any]]] = None


@dataclass
class LogMessage:
message: str
Expand All @@ -46,11 +32,27 @@ class LogMessage:
# Describes one auxiliary HTTP request/response pair extracted from a log message.
# Instances are built by handle_auxiliary_request from the message's "http" and
# "airbyte_cdk.stream" sections.
@dataclass
class AuxiliaryRequest:
# Title taken from the "http" section, prefixed with "Parent stream: " for substreams.
title: str
# Request category: "PARENT_STREAM" for substreams, otherwise the "http" section's "type"
# value (stringified). Compared against ASYNC_AUXILIARY_REQUEST_TYPES by
# is_async_auxiliary_request.
type: str
# Stringified "description" value from the "http" section.
description: str
# Request built by create_request_from_log_message.
request: HttpRequest
# Response built by create_response_from_log_message.
response: HttpResponse


# One page of a test read: the records collected for the page plus, when available,
# the HTTP request and response associated with it.
# NOTE(review): presumably request/response are the page's own HTTP exchange — confirm against the grouper.
@dataclass
class StreamReadPages:
records: List[object]
request: Optional[HttpRequest] = None
response: Optional[HttpResponse] = None


# One slice of a test read, as assembled by handle_current_slice.
@dataclass
class StreamReadSlices:
# Pages that belong to this slice.
pages: List[StreamReadPages]
# Descriptor of the slice; may be None.
slice_descriptor: Optional[Dict[str, Any]]
# handle_current_slice passes a one-element list holding the latest state message,
# or an empty list when there is none.
state: Optional[List[Dict[str, Any]]] = None
# Auxiliary requests attached to the slice; handle_current_slice passes an empty
# list when none were supplied.
auxiliary_requests: Optional[List[AuxiliaryRequest]] = None


@dataclass
class StreamRead(object):
logs: List[LogMessage]
Expand Down
142 changes: 120 additions & 22 deletions airbyte_cdk/connector_builder/test_reader/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
SchemaInferrer,
)

from .types import LOG_MESSAGES_OUTPUT_TYPE
from .types import ASYNC_AUXILIARY_REQUEST_TYPES, LOG_MESSAGES_OUTPUT_TYPE

# -------
# Parsers
Expand Down Expand Up @@ -226,7 +226,8 @@ def should_close_page(
at_least_one_page_in_group
and is_log_message(message)
and (
is_page_http_request(json_message) or message.log.message.startswith("slice:") # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
is_page_http_request(json_message)
or message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
)
)

Expand Down Expand Up @@ -330,6 +331,10 @@ def is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool:
return is_http_log(message) and message.get("http", {}).get("is_auxiliary", False)


def is_async_auxiliary_request(message: AuxiliaryRequest) -> bool:
    """
    Tells whether an auxiliary request belongs to one of the async request types.

    Parameters:
        message (AuxiliaryRequest): The auxiliary request to inspect.

    Returns:
        bool: True when the request's `type` is listed in ASYNC_AUXILIARY_REQUEST_TYPES.
    """
    request_type = message.type
    return request_type in ASYNC_AUXILIARY_REQUEST_TYPES


def is_log_message(message: AirbyteMessage) -> bool:
"""
Determines whether the provided message is of type LOG.
Expand Down Expand Up @@ -413,6 +418,7 @@ def handle_current_slice(
current_slice_pages: List[StreamReadPages],
current_slice_descriptor: Optional[Dict[str, Any]] = None,
latest_state_message: Optional[Dict[str, Any]] = None,
auxiliary_requests: Optional[List[AuxiliaryRequest]] = None,
) -> StreamReadSlices:
"""
Handles the current slice by packaging its pages, descriptor, and state into a StreamReadSlices instance.
Expand All @@ -421,6 +427,7 @@ def handle_current_slice(
current_slice_pages (List[StreamReadPages]): The pages to be included in the slice.
current_slice_descriptor (Optional[Dict[str, Any]]): Descriptor for the current slice, optional.
latest_state_message (Optional[Dict[str, Any]]): The latest state message, optional.
auxiliary_requests (Optional[List[AuxiliaryRequest]]): The auxiliary requests to include, optional.

Returns:
StreamReadSlices: An object containing the current slice's pages, descriptor, and state.
Expand All @@ -429,6 +436,7 @@ def handle_current_slice(
pages=current_slice_pages,
slice_descriptor=current_slice_descriptor,
state=[latest_state_message] if latest_state_message else [],
auxiliary_requests=auxiliary_requests if auxiliary_requests else [],
)


Expand Down Expand Up @@ -486,29 +494,24 @@ def handle_auxiliary_request(json_message: Dict[str, JsonType]) -> AuxiliaryRequ
Raises:
ValueError: If any of the "airbyte_cdk", "stream", or "http" fields is not a dictionary.
"""
airbyte_cdk = json_message.get("airbyte_cdk", {})

if not isinstance(airbyte_cdk, dict):
raise ValueError(
f"Expected airbyte_cdk to be a dict, got {airbyte_cdk} of type {type(airbyte_cdk)}"
)

stream = airbyte_cdk.get("stream", {})

if not isinstance(stream, dict):
raise ValueError(f"Expected stream to be a dict, got {stream} of type {type(stream)}")
airbyte_cdk = get_airbyte_cdk_from_message(json_message)
stream = get_stream_from_airbyte_cdk(airbyte_cdk)
title_prefix = get_auxiliary_request_title_prefix(stream)
http = get_http_property_from_message(json_message)
request_type = get_auxiliary_request_type(stream, http)

title_prefix = "Parent stream: " if stream.get("is_substream", False) else ""
http = json_message.get("http", {})

if not isinstance(http, dict):
raise ValueError(f"Expected http to be a dict, got {http} of type {type(http)}")
title = title_prefix + str(http.get("title", None))
description = str(http.get("description", None))
request = create_request_from_log_message(json_message)
response = create_response_from_log_message(json_message)

return AuxiliaryRequest(
title=title_prefix + str(http.get("title", None)),
description=str(http.get("description", None)),
request=create_request_from_log_message(json_message),
response=create_response_from_log_message(json_message),
title=title,
type=request_type,
description=description,
request=request,
response=response,
)


Expand Down Expand Up @@ -558,7 +561,8 @@ def handle_log_message(
at_least_one_page_in_group,
current_page_request,
current_page_response,
auxiliary_request or log_message,
auxiliary_request,
log_message,
)


Expand Down Expand Up @@ -589,3 +593,97 @@ def handle_record_message(
datetime_format_inferrer.accumulate(message.record) # type: ignore

return records_count


# -------
# Reusable Getters
# -------


def get_airbyte_cdk_from_message(json_message: Dict[str, JsonType]) -> dict:  # type: ignore
    """
    Extracts the "airbyte_cdk" section of a JSON message.

    When the message has no "airbyte_cdk" key, an empty dictionary is returned.

    Parameters:
        json_message (Dict[str, JsonType]): A dictionary representing the JSON message.

    Returns:
        dict: The value stored under "airbyte_cdk", or an empty dict if the key is absent.

    Raises:
        ValueError: If the value stored under "airbyte_cdk" is not a dictionary.
    """
    airbyte_cdk = json_message.get("airbyte_cdk", {})

    # Happy path first: anything that is already a dict is handed back untouched.
    if isinstance(airbyte_cdk, dict):
        return airbyte_cdk

    raise ValueError(
        f"Expected airbyte_cdk to be a dict, got {airbyte_cdk} of type {type(airbyte_cdk)}"
    )


def get_stream_from_airbyte_cdk(airbyte_cdk: dict) -> dict:  # type: ignore
    """
    Extracts the "stream" section from an "airbyte_cdk" dictionary.

    When the dictionary has no "stream" key, an empty dictionary is returned.

    Parameters:
        airbyte_cdk (dict): The dictionary representing the Airbyte CDK data.

    Returns:
        dict: The value stored under "stream", or an empty dict if the key is absent.

    Raises:
        ValueError: If the value stored under "stream" is not a dictionary.
    """
    stream = airbyte_cdk.get("stream", {})

    # Happy path first: anything that is already a dict is handed back untouched.
    if isinstance(stream, dict):
        return stream

    raise ValueError(f"Expected stream to be a dict, got {stream} of type {type(stream)}")


def get_auxiliary_request_title_prefix(stream: dict) -> str:  # type: ignore
    """
    Builds the title prefix for an auxiliary request.

    Parameters:
        stream (dict): The "stream" dictionary of the message.

    Returns:
        str: "Parent stream: " when the stream's "is_substream" value is truthy,
        otherwise an empty string.
    """
    if stream.get("is_substream", False):
        return "Parent stream: "
    return ""


def get_http_property_from_message(json_message: Dict[str, JsonType]) -> dict:  # type: ignore
    """
    Extracts the "http" section of a JSON message.

    When the message has no "http" key, an empty dictionary is returned.

    Parameters:
        json_message (Dict[str, JsonType]): A dictionary representing the JSON message.

    Returns:
        dict: The value stored under "http", or an empty dict if the key is absent.

    Raises:
        ValueError: If the value stored under "http" is not a dictionary.
    """
    http = json_message.get("http", {})

    # Happy path first: anything that is already a dict is handed back untouched.
    if isinstance(http, dict):
        return http

    raise ValueError(f"Expected http to be a dict, got {http} of type {type(http)}")


def get_auxiliary_request_type(stream: dict, http: dict) -> str:  # type: ignore
    """
    Resolves the type label of an auxiliary request.

    Parameters:
        stream (dict): The "stream" dictionary of the message.
        http (dict): The "http" dictionary of the message.

    Returns:
        str: "PARENT_STREAM" when the stream's "is_substream" value is truthy; otherwise
        the string form of the "type" value in `http` (the text "None" when it is absent).
    """
    if stream.get("is_substream", False):
        return "PARENT_STREAM"
    return str(http.get("type", None))
19 changes: 16 additions & 3 deletions airbyte_cdk/connector_builder/test_reader/message_grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Dict, Iterator, List, Mapping, Optional

from airbyte_cdk.connector_builder.models import (
AuxiliaryRequest,
HttpRequest,
HttpResponse,
StreamReadPages,
Expand All @@ -24,6 +25,7 @@
handle_current_slice,
handle_log_message,
handle_record_message,
is_async_auxiliary_request,
is_config_update_message,
is_log_message,
is_record_message,
Expand Down Expand Up @@ -89,6 +91,7 @@ def get_message_groups(
current_page_request: Optional[HttpRequest] = None
current_page_response: Optional[HttpResponse] = None
latest_state_message: Optional[Dict[str, Any]] = None
slice_auxiliary_requests: List[AuxiliaryRequest] = []

while records_count < limit and (message := next(messages, None)):
json_message = airbyte_message_to_json(message)
Expand All @@ -106,6 +109,7 @@ def get_message_groups(
current_slice_pages,
current_slice_descriptor,
latest_state_message,
slice_auxiliary_requests,
)
current_slice_descriptor = parse_slice_description(message.log.message) # type: ignore
current_slice_pages = []
Expand All @@ -118,16 +122,24 @@ def get_message_groups(
at_least_one_page_in_group,
current_page_request,
current_page_response,
log_or_auxiliary_request,
auxiliary_request,
log_message,
) = handle_log_message(
message,
json_message,
at_least_one_page_in_group,
current_page_request,
current_page_response,
)
if log_or_auxiliary_request:
yield log_or_auxiliary_request

if auxiliary_request:
if is_async_auxiliary_request(auxiliary_request):
slice_auxiliary_requests.append(auxiliary_request)
else:
yield auxiliary_request

if log_message:
yield log_message
elif is_trace_with_error(message):
if message.trace is not None:
yield message.trace
Expand Down Expand Up @@ -157,4 +169,5 @@ def get_message_groups(
current_slice_pages,
current_slice_descriptor,
latest_state_message,
slice_auxiliary_requests,
)
10 changes: 9 additions & 1 deletion airbyte_cdk/connector_builder/test_reader/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,13 @@
bool,
HttpRequest | None,
HttpResponse | None,
AuxiliaryRequest | AirbyteLogMessage | None,
AuxiliaryRequest | None,
AirbyteLogMessage | None,
]

# Auxiliary request `type` values that is_async_auxiliary_request treats as async.
# The message grouper collects requests of these types into the current slice's
# auxiliary requests instead of yielding them as standalone messages.
ASYNC_AUXILIARY_REQUEST_TYPES = [
"ASYNC_CREATE",
"ASYNC_POLL",
"ASYNC_ABORT",
"ASYNC_DELETE",
]
1 change: 1 addition & 0 deletions airbyte_cdk/sources/declarative/auth/token_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _refresh(self) -> None:
"Obtains session token",
None,
is_auxiliary=True,
type="AUTH",
),
)
if response is None:
Expand Down
Loading
Loading