Skip to content

fix(sql): Add fallback to source_defined_primary_key in CatalogProvider #627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions airbyte_cdk/sql/shared/catalog_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,12 @@ def get_primary_keys(
stream_name: str,
) -> list[str]:
"""Return the primary keys for the given stream."""
pks = self.get_configured_stream_info(stream_name).primary_key
if not pks:
return []
configured_stream = self.get_configured_stream_info(stream_name)
pks = (
configured_stream.primary_key
or configured_stream.stream.source_defined_primary_key
or []
)

normalized_pks: list[list[str]] = [
[LowerCaseNormalizer.normalize(c) for c in pk] for pk in pks
Expand Down
160 changes: 160 additions & 0 deletions unit_tests/sql/shared/test_catalog_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from unittest.mock import Mock

import pytest

from airbyte_cdk.models import AirbyteStream, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream
from airbyte_cdk.sql.shared.catalog_providers import CatalogProvider


class TestCatalogProvider:
"""Test cases for CatalogProvider.get_primary_keys() method."""

def test_get_primary_keys_uses_configured_primary_key_when_set(self):
"""Test that configured primary_key is used when set."""
stream = AirbyteStream(
name="test_stream",
json_schema={"type": "object", "properties": {"id": {"type": "string"}}},
supported_sync_modes=["full_refresh"],
source_defined_primary_key=[["source_id"]],
)
configured_stream = ConfiguredAirbyteStream(
stream=stream,
sync_mode="full_refresh",
destination_sync_mode="overwrite",
primary_key=[["configured_id"]],
)
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

provider = CatalogProvider(catalog)
result = provider.get_primary_keys("test_stream")

assert result == ["configured_id"]

def test_get_primary_keys_falls_back_to_source_defined_when_configured_empty(self):
"""Test that source_defined_primary_key is used when primary_key is empty."""
stream = AirbyteStream(
name="test_stream",
json_schema={"type": "object", "properties": {"id": {"type": "string"}}},
supported_sync_modes=["full_refresh"],
source_defined_primary_key=[["source_id"]],
)
configured_stream = ConfiguredAirbyteStream(
stream=stream,
sync_mode="full_refresh",
destination_sync_mode="overwrite",
primary_key=[], # Empty configured primary key
)
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

provider = CatalogProvider(catalog)
result = provider.get_primary_keys("test_stream")

assert result == ["source_id"]

def test_get_primary_keys_falls_back_to_source_defined_when_configured_none(self):
"""Test that source_defined_primary_key is used when primary_key is None."""
stream = AirbyteStream(
name="test_stream",
json_schema={"type": "object", "properties": {"id": {"type": "string"}}},
supported_sync_modes=["full_refresh"],
source_defined_primary_key=[["source_id"]],
)
configured_stream = ConfiguredAirbyteStream(
stream=stream,
sync_mode="full_refresh",
destination_sync_mode="overwrite",
primary_key=None, # None configured primary key
)
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

provider = CatalogProvider(catalog)
result = provider.get_primary_keys("test_stream")

assert result == ["source_id"]

def test_get_primary_keys_returns_empty_when_both_empty(self):
"""Test that empty list is returned when both primary keys are empty."""
stream = AirbyteStream(
name="test_stream",
json_schema={"type": "object", "properties": {"id": {"type": "string"}}},
supported_sync_modes=["full_refresh"],
source_defined_primary_key=[], # Empty source-defined primary key
)
configured_stream = ConfiguredAirbyteStream(
stream=stream,
sync_mode="full_refresh",
destination_sync_mode="overwrite",
primary_key=[], # Empty configured primary key
)
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

provider = CatalogProvider(catalog)
result = provider.get_primary_keys("test_stream")

assert result == []

def test_get_primary_keys_returns_empty_when_both_none(self):
"""Test that empty list is returned when both primary keys are None."""
stream = AirbyteStream(
name="test_stream",
json_schema={"type": "object", "properties": {"id": {"type": "string"}}},
supported_sync_modes=["full_refresh"],
source_defined_primary_key=None, # None source-defined primary key
)
configured_stream = ConfiguredAirbyteStream(
stream=stream,
sync_mode="full_refresh",
destination_sync_mode="overwrite",
primary_key=None, # None configured primary key
)
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

provider = CatalogProvider(catalog)
result = provider.get_primary_keys("test_stream")

assert result == []

def test_get_primary_keys_handles_composite_keys_from_source_defined(self):
"""Test that composite primary keys work correctly with source-defined fallback."""
stream = AirbyteStream(
name="test_stream",
json_schema={
"type": "object",
"properties": {"id1": {"type": "string"}, "id2": {"type": "string"}},
},
supported_sync_modes=["full_refresh"],
source_defined_primary_key=[["id1"], ["id2"]], # Composite primary key
)
configured_stream = ConfiguredAirbyteStream(
stream=stream,
sync_mode="full_refresh",
destination_sync_mode="overwrite",
primary_key=[], # Empty configured primary key
)
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

provider = CatalogProvider(catalog)
result = provider.get_primary_keys("test_stream")

assert result == ["id1", "id2"]

def test_get_primary_keys_normalizes_case_for_source_defined(self):
"""Test that primary keys from source-defined are normalized to lowercase."""
stream = AirbyteStream(
name="test_stream",
json_schema={"type": "object", "properties": {"ID": {"type": "string"}}},
supported_sync_modes=["full_refresh"],
source_defined_primary_key=[["ID"]], # Uppercase primary key
)
configured_stream = ConfiguredAirbyteStream(
stream=stream,
sync_mode="full_refresh",
destination_sync_mode="overwrite",
primary_key=[], # Empty configured primary key
)
catalog = ConfiguredAirbyteCatalog(streams=[configured_stream])

provider = CatalogProvider(catalog)
result = provider.get_primary_keys("test_stream")

assert result == ["id"]
Loading