fix(python): Check for duplicate column names in read_database cursor result, raising DuplicateError if found #18548

Merged
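Not part of the PR text, just a usage illustration of the behaviour this change introduces: after the fix, a query whose result set repeats a column name raises `DuplicateError` instead of silently collapsing the duplicated columns. A minimal sketch against an in-memory SQLite database via SQLAlchemy (connection details are placeholders):

```python
import polars as pl
from polars.exceptions import DuplicateError
from sqlalchemy import create_engine

# Any SQLAlchemy connection works; in-memory SQLite keeps the example self-contained.
conn = create_engine("sqlite:///:memory:").connect()

try:
    # Both output columns are named "n", so the cursor description contains a duplicate.
    pl.read_database('SELECT 1 AS "n", 2 AS "n"', connection=conn)
except DuplicateError as exc:
    print(exc)  # column 'n' appears more than once in the query/result cursor
```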
33 changes: 20 additions & 13 deletions py-polars/polars/io/database/_executor.py
@@ -9,10 +9,12 @@
 from polars import functions as F
 from polars._utils.various import parse_version
 from polars.convert import from_arrow
-from polars.datatypes import (
-    N_INFER_DEFAULT,
+from polars.datatypes import N_INFER_DEFAULT
+from polars.exceptions import (
+    DuplicateError,
+    ModuleUpgradeRequiredError,
+    UnsuitableSQLError,
 )
-from polars.exceptions import ModuleUpgradeRequiredError, UnsuitableSQLError
 from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
 from polars.io.database._cursor_proxies import ODBCCursorProxy, SurrealDBCursorProxy
 from polars.io.database._inference import _infer_dtype_from_cursor_description
@@ -266,25 +268,25 @@ def _from_rows(
         if hasattr(self.result, "fetchall"):
             if self.driver_name == "sqlalchemy":
                 if hasattr(self.result, "cursor"):
-                    cursor_desc = {
-                        d[0]: d[1:] for d in self.result.cursor.description
-                    }
+                    cursor_desc = [
+                        (d[0], d[1:]) for d in self.result.cursor.description
+                    ]
                 elif hasattr(self.result, "_metadata"):
-                    cursor_desc = {k: None for k in self.result._metadata.keys}
+                    cursor_desc = [(k, None) for k in self.result._metadata.keys]
                 else:
                     msg = f"Unable to determine metadata from query result; {self.result!r}"
                     raise ValueError(msg)
 
             elif hasattr(self.result, "description"):
-                cursor_desc = {d[0]: d[1:] for d in self.result.description}
+                cursor_desc = [(d[0], d[1:]) for d in self.result.description]
             else:
-                cursor_desc = {}
+                cursor_desc = []
 
             schema_overrides = self._inject_type_overrides(
                 description=cursor_desc,
                 schema_overrides=(schema_overrides or {}),
             )
-            result_columns = list(cursor_desc)
+            result_columns = [nm for nm, _ in cursor_desc]
             frames = (
                 DataFrame(
                     data=rows,
@@ -307,7 +309,7 @@ def _from_rows(
 
     def _inject_type_overrides(
         self,
-        description: dict[str, Any],
+        description: list[tuple[str, Any]],
         schema_overrides: SchemaDict,
     ) -> SchemaDict:
         """
@@ -320,11 +322,16 @@
         We currently only do the additional inference from string/python type values.
         (Further refinement will require per-driver module knowledge and lookups).
         """
-        for nm, desc in description.items():
-            if desc is not None and nm not in schema_overrides:
+        dupe_check = set()
+        for nm, desc in description:
+            if nm in dupe_check:
+                msg = f"column {nm!r} appears more than once in the query/result cursor"
+                raise DuplicateError(msg)
+            elif desc is not None and nm not in schema_overrides:
                 dtype = _infer_dtype_from_cursor_description(self.cursor, desc)
                 if dtype is not None:
                     schema_overrides[nm] = dtype  # type: ignore[index]
+            dupe_check.add(nm)
 
         return schema_overrides
 
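A brief note on the approach (my reading of the diff, not wording from the PR): `cursor_desc` changes from a dict keyed by column name to a list of `(name, description)` tuples because a dict silently keeps only the last entry for a repeated key, which is exactly the condition the new check in `_inject_type_overrides` needs to detect. A standalone sketch of the same idea, using a hypothetical `description` list in place of a real cursor:

```python
from polars.exceptions import DuplicateError

# Hypothetical cursor metadata reduced to (name, type_info) pairs; "name" appears twice.
description = [("name", None), ("value", None), ("name", None)]

# A dict comprehension would mask the duplicate: only two keys survive.
assert len({nm: desc for nm, desc in description}) == 2

# The list-of-tuples form keeps every column, so the duplicate can be reported.
try:
    seen: set[str] = set()
    for nm, _desc in description:
        if nm in seen:
            msg = f"column {nm!r} appears more than once in the query/result cursor"
            raise DuplicateError(msg)
        seen.add(nm)
except DuplicateError as exc:
    print(exc)  # column 'name' appears more than once in the query/result cursor
```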
19 changes: 18 additions & 1 deletion py-polars/tests/unit/io/database/test_read.py
@@ -18,7 +18,7 @@
 
 import polars as pl
 from polars._utils.various import parse_version
-from polars.exceptions import UnsuitableSQLError
+from polars.exceptions import DuplicateError, UnsuitableSQLError
 from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
 from polars.testing import assert_frame_equal
 
@@ -678,6 +678,23 @@ def test_read_database_exceptions(
         read_database(**params)
 
 
+@pytest.mark.parametrize(
+    "query",
+    [
+        "SELECT 1, 1 FROM test_data",
+        'SELECT 1 AS "n", 2 AS "n" FROM test_data',
+        'SELECT name, value AS "name" FROM test_data',
+    ],
+)
+def test_read_database_duplicate_column_error(tmp_sqlite_db: Path, query: str) -> None:
+    alchemy_conn = create_engine(f"sqlite:///{tmp_sqlite_db}").connect()
+    with pytest.raises(
+        DuplicateError,
+        match="column .+ appears more than once in the query/result cursor",
+    ):
+        pl.read_database(query, connection=alchemy_conn)
+
+
 @pytest.mark.parametrize(
     "uri",
     [