Skip to content

Commit

Permalink
feat(airbyte-cdk): replace pydantic BaseModel with dataclasses + …
Browse files Browse the repository at this point in the history
…`serpyco-rs` in protocol (#44444)

Signed-off-by: Artem Inzhyyants <artem.inzhyyants@gmail.com>
  • Loading branch information
artem1205 authored Sep 2, 2024
1 parent 21fddbd commit df34893
Show file tree
Hide file tree
Showing 125 changed files with 2,730 additions and 2,270 deletions.
12 changes: 10 additions & 2 deletions airbyte-cdk/python/airbyte_cdk/config_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
from copy import copy
from typing import Any, List, MutableMapping

from airbyte_cdk.models import AirbyteControlConnectorConfigMessage, AirbyteControlMessage, AirbyteMessage, OrchestratorType, Type
from airbyte_cdk.models import (
AirbyteControlConnectorConfigMessage,
AirbyteControlMessage,
AirbyteMessage,
AirbyteMessageSerializer,
OrchestratorType,
Type,
)
from orjson import orjson


class ObservedDict(dict): # type: ignore # disallow_any_generics is set to True, and dict is equivalent to dict[Any]
Expand Down Expand Up @@ -76,7 +84,7 @@ def emit_configuration_as_airbyte_control_message(config: MutableMapping[str, An
See the airbyte_cdk.sources.message package
"""
airbyte_message = create_connector_config_control_message(config)
print(airbyte_message.model_dump_json(exclude_unset=True))
print(orjson.dumps(AirbyteMessageSerializer.dump(airbyte_message)).decode())


def create_connector_config_control_message(config: MutableMapping[str, Any]) -> AirbyteMessage:
Expand Down
4 changes: 2 additions & 2 deletions airbyte-cdk/python/airbyte_cdk/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, Generic, Mapping, Optional, Protocol, TypeVar

import yaml
from airbyte_cdk.models import AirbyteConnectionStatus, ConnectorSpecification
from airbyte_cdk.models import AirbyteConnectionStatus, ConnectorSpecification, ConnectorSpecificationSerializer


def load_optional_package_file(package: str, filename: str) -> Optional[bytes]:
Expand Down Expand Up @@ -84,7 +84,7 @@ def spec(self, logger: logging.Logger) -> ConnectorSpecification:
else:
raise FileNotFoundError("Unable to find spec.yaml or spec.json in the package.")

return ConnectorSpecification.parse_obj(spec_obj)
return ConnectorSpecificationSerializer.load(spec_obj)

@abstractmethod
def check(self, logger: logging.Logger, config: TConfig) -> AirbyteConnectionStatus:
Expand Down
15 changes: 11 additions & 4 deletions airbyte-cdk/python/airbyte_cdk/connector_builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,17 @@
from airbyte_cdk.connector import BaseConnector
from airbyte_cdk.connector_builder.connector_builder_handler import TestReadLimits, create_source, get_limits, read_stream, resolve_manifest
from airbyte_cdk.entrypoint import AirbyteEntrypoint
from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.models import (
AirbyteMessage,
AirbyteMessageSerializer,
AirbyteStateMessage,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteCatalogSerializer,
)
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from airbyte_cdk.sources.source import Source
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from orjson import orjson


def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str, Any], Optional[ConfiguredAirbyteCatalog], Any]:
Expand All @@ -32,7 +39,7 @@ def get_config_and_catalog_from_args(args: List[str]) -> Tuple[str, Mapping[str,

command = config["__command"]
if command == "test_read":
catalog = ConfiguredAirbyteCatalog.parse_obj(BaseConnector.read_config(catalog_path))
catalog = ConfiguredAirbyteCatalogSerializer.load(BaseConnector.read_config(catalog_path))
state = Source.read_state(state_path)
else:
catalog = None
Expand Down Expand Up @@ -67,7 +74,7 @@ def handle_request(args: List[str]) -> AirbyteMessage:
command, config, catalog, state = get_config_and_catalog_from_args(args)
limits = get_limits(config)
source = create_source(config, limits)
return handle_connector_builder_request(source, command, config, catalog, state, limits).json(exclude_unset=True)
return AirbyteMessageSerializer.dump(handle_connector_builder_request(source, command, config, catalog, state, limits)) # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage


if __name__ == "__main__":
Expand All @@ -76,4 +83,4 @@ def handle_request(args: List[str]) -> AirbyteMessage:
except Exception as exc:
error = AirbyteTracedException.from_exception(exc, message=f"Error handling request: {str(exc)}")
m = error.as_airbyte_message()
print(error.as_airbyte_message().model_dump_json(exclude_unset=True))
print(orjson.dumps(AirbyteMessageSerializer.dump(m)).decode())
34 changes: 17 additions & 17 deletions airbyte-cdk/python/airbyte_cdk/connector_builder/message_grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,7 @@
StreamReadSlices,
)
from airbyte_cdk.entrypoint import AirbyteEntrypoint
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from airbyte_cdk.sources.utils.types import JsonType
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException
from airbyte_protocol.models.airbyte_protocol import (
from airbyte_cdk.models import (
AirbyteControlMessage,
AirbyteLogMessage,
AirbyteMessage,
Expand All @@ -34,7 +28,13 @@
OrchestratorType,
TraceType,
)
from airbyte_protocol.models.airbyte_protocol import Type as MessageType
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.sources.utils.slice_logger import SliceLogger
from airbyte_cdk.sources.utils.types import JsonType
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException


class MessageGrouper:
Expand Down Expand Up @@ -182,19 +182,19 @@ def _get_message_groups(
if (
at_least_one_page_in_group
and message.type == MessageType.LOG
and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX)
and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
):
yield StreamReadSlices(
pages=current_slice_pages,
slice_descriptor=current_slice_descriptor,
state=[latest_state_message] if latest_state_message else [],
)
current_slice_descriptor = self._parse_slice_description(message.log.message)
current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
current_slice_pages = []
at_least_one_page_in_group = False
elif message.type == MessageType.LOG and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX):
elif message.type == MessageType.LOG and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX): # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
# parsing the first slice
current_slice_descriptor = self._parse_slice_description(message.log.message)
current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
elif message.type == MessageType.LOG:
if json_message is not None and self._is_http_log(json_message):
if self._is_auxiliary_http_request(json_message):
Expand All @@ -221,17 +221,17 @@ def _get_message_groups(
else:
yield message.log
elif message.type == MessageType.TRACE:
if message.trace.type == TraceType.ERROR:
if message.trace.type == TraceType.ERROR: # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has trace.type
yield message.trace
elif message.type == MessageType.RECORD:
current_page_records.append(message.record.data)
current_page_records.append(message.record.data) # type: ignore[union-attr] # AirbyteMessage with MessageType.RECORD has record.data
records_count += 1
schema_inferrer.accumulate(message.record)
datetime_format_inferrer.accumulate(message.record)
elif message.type == MessageType.CONTROL and message.control.type == OrchestratorType.CONNECTOR_CONFIG:
elif message.type == MessageType.CONTROL and message.control.type == OrchestratorType.CONNECTOR_CONFIG: # type: ignore[union-attr] # AirbyteMessage with MessageType.CONTROL has control.type
yield message.control
elif message.type == MessageType.STATE:
latest_state_message = message.state
latest_state_message = message.state # type: ignore[assignment]
else:
if current_page_request or current_page_response or current_page_records:
self._close_page(current_page_request, current_page_response, current_slice_pages, current_page_records)
Expand All @@ -246,7 +246,7 @@ def _need_to_close_page(at_least_one_page_in_group: bool, message: AirbyteMessag
return (
at_least_one_page_in_group
and message.type == MessageType.LOG
and (MessageGrouper._is_page_http_request(json_message) or message.log.message.startswith("slice:"))
and (MessageGrouper._is_page_http_request(json_message) or message.log.message.startswith("slice:")) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
)

@staticmethod
Expand Down
12 changes: 6 additions & 6 deletions airbyte-cdk/python/airbyte_cdk/destinations/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

from airbyte_cdk.connector import Connector
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type
from airbyte_cdk.models import AirbyteMessage, AirbyteMessageSerializer, ConfiguredAirbyteCatalog, ConfiguredAirbyteCatalogSerializer, Type
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from pydantic import ValidationError as V2ValidationError
from orjson import orjson

logger = logging.getLogger("airbyte")

Expand All @@ -36,14 +36,14 @@ def _parse_input_stream(self, input_stream: io.TextIOWrapper) -> Iterable[Airbyt
"""Reads from stdin, converting to Airbyte messages"""
for line in input_stream:
try:
yield AirbyteMessage.parse_raw(line)
except V2ValidationError:
yield AirbyteMessageSerializer.load(orjson.loads(line))
except orjson.JSONDecodeError:
logger.info(f"ignoring input which can't be deserialized as Airbyte Message: {line}")

def _run_write(
self, config: Mapping[str, Any], configured_catalog_path: str, input_stream: io.TextIOWrapper
) -> Iterable[AirbyteMessage]:
catalog = ConfiguredAirbyteCatalog.parse_file(configured_catalog_path)
catalog = ConfiguredAirbyteCatalogSerializer.load(orjson.loads(open(configured_catalog_path).read()))
input_messages = self._parse_input_stream(input_stream)
logger.info("Begin writing to the destination...")
yield from self.write(config=config, configured_catalog=catalog, input_messages=input_messages)
Expand Down Expand Up @@ -117,4 +117,4 @@ def run(self, args: List[str]) -> None:
parsed_args = self.parse_args(args)
output_messages = self.run_cmd(parsed_args)
for message in output_messages:
print(message.model_dump_json(exclude_unset=True))
print(orjson.dumps(AirbyteMessageSerializer.dump(message)).decode())
22 changes: 15 additions & 7 deletions airbyte-cdk/python/airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,23 @@
from airbyte_cdk.connector import TConfig
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
from airbyte_cdk.logger import init_logger
from airbyte_cdk.models import AirbyteMessage, FailureType, Status, Type
from airbyte_cdk.models.airbyte_protocol import AirbyteStateStats, ConnectorSpecification # type: ignore [attr-defined]
from airbyte_cdk.models import ( # type: ignore [attr-defined]
AirbyteMessage,
AirbyteMessageSerializer,
AirbyteStateStats,
ConnectorSpecification,
FailureType,
Status,
Type,
)
from airbyte_cdk.sources import Source
from airbyte_cdk.sources.connector_state_manager import HashableStreamDescriptor
from airbyte_cdk.sources.utils.schema_helpers import check_config_against_spec_or_exit, split_config
from airbyte_cdk.utils import PrintBuffer, is_cloud_environment, message_utils
from airbyte_cdk.utils.airbyte_secrets_utils import get_secrets, update_secrets
from airbyte_cdk.utils.constants import ENV_REQUEST_CACHE_PATH
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from orjson import orjson
from requests import PreparedRequest, Response, Session

logger = init_logger("airbyte")
Expand Down Expand Up @@ -170,13 +178,13 @@ def read(self, source_spec: ConnectorSpecification, config: TConfig, catalog: An
def handle_record_counts(message: AirbyteMessage, stream_message_count: DefaultDict[HashableStreamDescriptor, float]) -> AirbyteMessage:
match message.type:
case Type.RECORD:
stream_message_count[HashableStreamDescriptor(name=message.record.stream, namespace=message.record.namespace)] += 1.0
stream_message_count[HashableStreamDescriptor(name=message.record.stream, namespace=message.record.namespace)] += 1.0 # type: ignore[union-attr] # record has `stream` and `namespace`
case Type.STATE:
stream_descriptor = message_utils.get_stream_descriptor(message)

# Set record count from the counter onto the state message
message.state.sourceStats = message.state.sourceStats or AirbyteStateStats()
message.state.sourceStats.recordCount = stream_message_count.get(stream_descriptor, 0.0)
message.state.sourceStats = message.state.sourceStats or AirbyteStateStats() # type: ignore[union-attr] # state has `sourceStats`
message.state.sourceStats.recordCount = stream_message_count.get(stream_descriptor, 0.0) # type: ignore[union-attr] # state has `sourceStats`

# Reset the counter
stream_message_count[stream_descriptor] = 0.0
Expand All @@ -197,8 +205,8 @@ def set_up_secret_filter(config: TConfig, connection_specification: Mapping[str,
update_secrets(config_secrets)

@staticmethod
def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> Any:
return airbyte_message.model_dump_json(exclude_unset=True)
def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str:
return orjson.dumps(AirbyteMessageSerializer.dump(airbyte_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string

@classmethod
def extract_state(cls, args: List[str]) -> Optional[Any]:
Expand Down
17 changes: 9 additions & 8 deletions airbyte-cdk/python/airbyte_cdk/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import logging.config
from typing import Any, Mapping, Optional, Tuple

from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteMessageSerializer, Level, Type
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
from orjson import orjson

LOGGING_CONFIG = {
"version": 1,
Expand Down Expand Up @@ -42,11 +43,11 @@ class AirbyteLogFormatter(logging.Formatter):

# Transforming Python log levels to Airbyte protocol log levels
level_mapping = {
logging.FATAL: "FATAL",
logging.ERROR: "ERROR",
logging.WARNING: "WARN",
logging.INFO: "INFO",
logging.DEBUG: "DEBUG",
logging.FATAL: Level.FATAL,
logging.ERROR: Level.ERROR,
logging.WARNING: Level.WARN,
logging.INFO: Level.INFO,
logging.DEBUG: Level.DEBUG,
}

def format(self, record: logging.LogRecord) -> str:
Expand All @@ -59,8 +60,8 @@ def format(self, record: logging.LogRecord) -> str:
else:
message = super().format(record)
message = filter_secrets(message)
log_message = AirbyteMessage(type="LOG", log=AirbyteLogMessage(level=airbyte_level, message=message))
return log_message.model_dump_json(exclude_unset=True) # type: ignore
log_message = AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message))
return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode() # type: ignore[no-any-return] # orjson.dumps(message).decode() always returns string

@staticmethod
def extract_extra_args_from_record(record: logging.LogRecord) -> Mapping[str, Any]:
Expand Down
10 changes: 10 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# of airbyte-cdk rather than a standalone package.
from .airbyte_protocol import (
AdvancedAuth,
AirbyteStateStats,
AirbyteAnalyticsTraceMessage,
AirbyteCatalog,
AirbyteConnectionStatus,
Expand Down Expand Up @@ -58,3 +59,12 @@
TimeWithoutTimezone,
TimeWithTimezone,
)

from .airbyte_protocol_serializers import (
AirbyteStreamStateSerializer,
AirbyteStateMessageSerializer,
AirbyteMessageSerializer,
ConfiguredAirbyteCatalogSerializer,
ConfiguredAirbyteStreamSerializer,
ConnectorSpecificationSerializer,
)
Loading

0 comments on commit df34893

Please sign in to comment.