Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions adbc_drivers_validation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class DriverSetup(BaseModel):

class DriverFeatures(BaseModel):
connection_get_table_schema: bool = Field(default=False)
connection_get_statistics: bool = Field(default=False)
connection_set_current_catalog: bool = Field(default=False)
connection_set_current_schema: bool = Field(default=False)
connection_transactions: bool = Field(default=False)
Expand Down
203 changes: 203 additions & 0 deletions adbc_drivers_validation/tests/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,51 @@
from adbc_drivers_validation import compare, model
from adbc_drivers_validation.utils import scoped_trace

# Expected schema for GetStatistics (ADBC spec)
# Built up from innermost to outermost types

_STATISTIC_VALUE_TYPE = pyarrow.dense_union(
[
pyarrow.field("int64", pyarrow.int64()),
pyarrow.field("uint64", pyarrow.uint64()),
pyarrow.field("float64", pyarrow.float64()),
pyarrow.field("binary", pyarrow.binary()),
],
type_codes=[0, 1, 2, 3],
)

_STATISTICS_STRUCT = pyarrow.struct(
[
pyarrow.field("table_name", pyarrow.string(), nullable=False),
pyarrow.field("column_name", pyarrow.string()),
pyarrow.field("statistic_key", pyarrow.int16(), nullable=False),
pyarrow.field("statistic_value", _STATISTIC_VALUE_TYPE, nullable=False),
pyarrow.field("statistic_is_approximate", pyarrow.bool_(), nullable=False),
]
)

_DB_SCHEMA_STRUCT = pyarrow.struct(
[
pyarrow.field("db_schema_name", pyarrow.string()),
pyarrow.field(
"db_schema_statistics",
pyarrow.list_(_STATISTICS_STRUCT),
nullable=False,
),
]
)

_EXPECTED_GET_STATISTICS_SCHEMA = pyarrow.schema(
[
pyarrow.field("catalog_name", pyarrow.string()),
pyarrow.field(
"catalog_db_schemas",
pyarrow.list_(_DB_SCHEMA_STRUCT),
nullable=False,
),
]
)


def generate_tests(
all_quirks: list[model.DriverQuirks], metafunc: pytest.Metafunc
Expand All @@ -56,6 +101,9 @@ def generate_tests(
"test_get_objects_"
):
marks.append(pytest.mark.xfail(reason="not implemented"))
elif metafunc.definition.name == "test_get_statistics":
if not f.connection_get_statistics:
marks.append(pytest.mark.skip(reason="not implemented"))

combinations.append(pytest.param(driver_param, id=driver_param, marks=marks))
metafunc.parametrize(
Expand Down Expand Up @@ -748,6 +796,64 @@ def get_objects_constraints(
with scoped_trace(stmt):
cursor.execute(stmt)

@pytest.fixture(scope="function")
def get_statistics_table(
self,
driver: model.DriverQuirks,
conn: adbc_driver_manager.dbapi.Connection,
) -> typing.Generator[tuple[str | None, str | None, str], None, None]:
"""Fixture that creates a table with test data for GetStatistics tests."""
table_name = f"statistics{secrets.token_hex(8)}"

with conn.cursor() as cursor:
driver.try_drop_table(cursor, table_name=table_name)

# Create and populate table
if driver.features.statement_bulk_ingest:
# Use bulk ingest if available
schema = pyarrow.schema(
[
("id", pyarrow.int32()),
("name", pyarrow.string()),
("value", pyarrow.float64()),
]
)
data = pyarrow.Table.from_pydict(
{
"id": [1, 2, 3],
"name": ["foo", "bar", None],
"value": [1.5, 2.5, 3.5],
},
schema=schema,
)
cursor.adbc_ingest(table_name, data)
else:
# Fall back to CREATE TABLE + INSERT
quoted = driver.quote_identifier(table_name)
cursor.execute(
f"CREATE TABLE {quoted} (id INT, name VARCHAR(100), value REAL)"
)
cursor.execute(
f"INSERT INTO {quoted} (id, name, value) VALUES (1, 'foo', 1.5)"
)
cursor.execute(
f"INSERT INTO {quoted} (id, name, value) VALUES (2, 'bar', 2.5)"
)
cursor.execute(
f"INSERT INTO {quoted} (id, name, value) VALUES (3, NULL, 3.5)"
)

table_id = (
driver.features.current_catalog,
driver.features.current_schema,
table_name,
)

yield table_id

with conn.cursor() as cursor:
driver.try_drop_table(cursor, table_name=table_name)

def get_constraints(
self,
driver: model.DriverQuirks,
Expand Down Expand Up @@ -962,6 +1068,103 @@ def test_get_objects_constraints_unique(
else:
assert constraints[1]["constraint_column_names"] == ["c", "b"]

def test_get_statistics(
self,
driver: model.DriverQuirks,
conn: adbc_driver_manager.dbapi.Connection,
get_statistics_table: tuple[str, str, str],
) -> None:
"""Test GetStatistics"""
assert hasattr(conn, "adbc_get_statistics"), (
"Driver claims to support GetStatistics but adbc_driver_manager DBAPI does not expose it"
)

table_id = get_statistics_table
table_name = table_id[-1]

# Call GetStatistics with all filters
reader = conn.adbc_get_statistics(
catalog_filter=driver.features.current_catalog,
db_schema_filter=driver.features.current_schema,
table_name_filter=table_name,
approximate=True,
)
table = reader.read_all()

# Verify schema matches ADBC spec
assert table.schema.equals(_EXPECTED_GET_STATISTICS_SCHEMA), (
"GetStatistics returned schema does not match ADBC specification"
)

# Find and verify table statistics
table_found = False
table_stats = []
column_stats = {}

for row in table.to_pylist():
# Verify catalog name
if driver.features.current_catalog:
assert row["catalog_name"] == driver.features.current_catalog, (
f"Expected catalog {driver.features.current_catalog}, got {row['catalog_name']}"
)

for sch in row["catalog_db_schemas"]:
# Verify schema name
if driver.features.current_schema:
assert sch["db_schema_name"] == driver.features.current_schema, (
f"Expected schema {driver.features.current_schema}, got {sch['db_schema_name']}"
)

for stat in sch["db_schema_statistics"]:
if stat["table_name"] == table_name:
table_found = True

# Organize statistics by table vs column
if stat["column_name"] is None:
# Table-level statistic
table_stats.append(stat)
else:
# Column-level statistic
col_name = stat["column_name"]
if col_name not in column_stats:
column_stats[col_name] = []
column_stats[col_name].append(stat)

# Verify table was found
assert table_found, f"Table {table_name} not found in statistics results"

# Validate statistics structure if any are present
all_stats = table_stats + [s for stats in column_stats.values() for s in stats]

for stat in all_stats:
# Verify statistic key is valid. Values in [0, 1024) are reserved for ADBC
assert 0 <= stat["statistic_key"] <= 1024, (
f"Invalid statistic key: {stat['statistic_key']} (must be 0-1024)"
)

# If row count statistic is present, verify it's reasonable since approx = true

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this works OK to start but it may be good to extend DriverFeatures with a list of expected statistic names, sort of like with xdbc

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good. we can merge this first and tackle this later with seperate PR.

row_count_stat = next(
(s for s in table_stats if s["statistic_key"] == 6),
None,
)
if row_count_stat is not None:
row_count_value = row_count_stat["statistic_value"]
assert row_count_value >= 3, (
f"Expected at least 3 rows, got {row_count_value}"
)

# If null count for 'name' column is present, verify it's > 0
if "name" in column_stats:
null_count_stat = next(
(s for s in column_stats["name"] if s["statistic_key"] == 5),
None,
)
if null_count_stat:
null_count = null_count_stat["statistic_value"]
assert null_count >= 1, (
f"Expected at least 1 null in 'name' column, got {null_count}"
)

def test_repl(
self,
driver: model.DriverQuirks,
Expand Down
Loading