Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import enum
import functools
import typing

import adbc_driver_manager

Expand Down Expand Up @@ -53,9 +54,15 @@ class StatementOptions(enum.Enum):
USE_COPY = "adbc.postgresql.use_copy"


def connect(uri: str) -> adbc_driver_manager.AdbcDatabase:
def connect(
uri: str,
db_kwargs: typing.Optional[typing.Dict[str, str]] = None,
) -> adbc_driver_manager.AdbcDatabase:
"""Create a low level ADBC connection to PostgreSQL."""
return adbc_driver_manager.AdbcDatabase(driver=_driver_path(), uri=uri)
db_options = dict(db_kwargs or {})
db_options["driver"] = _driver_path()
db_options["uri"] = uri
return adbc_driver_manager.AdbcDatabase(**db_options)


@functools.lru_cache
Expand Down
10 changes: 6 additions & 4 deletions python/adbc_driver_postgresql/adbc_driver_postgresql/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def connect(
uri: str,
db_kwargs: typing.Optional[typing.Dict[str, str]] = None,
conn_kwargs: typing.Optional[typing.Dict[str, str]] = None,
**kwargs
**kwargs,
) -> "Connection":
"""
Connect to PostgreSQL via ADBC.
Expand All @@ -118,9 +118,11 @@ def connect(
conn = None

try:
db = adbc_driver_postgresql.connect(uri)
conn = adbc_driver_manager.AdbcConnection(db)
return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
db = adbc_driver_postgresql.connect(uri, db_kwargs=db_kwargs)
conn = adbc_driver_manager.AdbcConnection(db, **(conn_kwargs or {}))
return adbc_driver_manager.dbapi.Connection(
db, conn, conn_kwargs=conn_kwargs, **kwargs
)
except Exception:
if conn:
conn.close()
Expand Down
14 changes: 14 additions & 0 deletions python/adbc_driver_postgresql/tests/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,17 @@ def status() -> str:
assert status() == "active"
postgres.rollback()
assert status() == "intrans"


def test_connect_conn_kwargs_db_schema(postgres_uri: str, postgres: dbapi.Connection):
"""Verify current DB schema can be set via conn_kwargs."""
schema_key = "adbc.connection.db_schema"
schema_name = "dbapi_test_schema_via_option"

with postgres.cursor() as cur:
cur.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE")
cur.execute(f"CREATE SCHEMA {schema_name}")
postgres.commit()
with dbapi.connect(postgres_uri, conn_kwargs={schema_key: schema_name}) as conn:
option_value = conn.adbc_connection.get_option(schema_key)
assert option_value == schema_name