Skip to content

Commit 9be7df8

Browse files
committed
Ability to pass custom DriverConfig kwargs
1 parent 66bf9e5 commit 9be7df8

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

tests/test_connections.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,16 @@ def connection(
372372
finally:
373373
conn.close()
374374

375+
def test_connect_with_custom_driver_config_kwargs(
376+
self, connection_kwargs: dict
377+
) -> None:
378+
connection_kwargs["driver_config_kwargs"] = {
379+
"grpc_keep_alive_timeout": 777,
380+
}
381+
connection = dbapi.connect(**connection_kwargs)
382+
assert connection._driver._driver_config.grpc_keep_alive_timeout == 777
383+
connection.close()
384+
375385
@pytest.mark.parametrize(
376386
("isolation_level", "read_only"),
377387
[
@@ -461,6 +471,17 @@ def close() -> None:
461471

462472
await greenlet_spawn(close)
463473

474+
@pytest.mark.asyncio
475+
async def test_connect_with_custom_driver_config_kwargs(
476+
self, connection_kwargs: dict
477+
) -> None:
478+
connection_kwargs["driver_config_kwargs"] = {
479+
"grpc_keep_alive_timeout": 777,
480+
}
481+
connection = await dbapi.async_connect(**connection_kwargs)
482+
assert connection._driver._driver_config.grpc_keep_alive_timeout == 777
483+
await connection.close()
484+
464485
@pytest.mark.asyncio
465486
@pytest.mark.parametrize(
466487
("isolation_level", "read_only"),

ydb_dbapi/connections.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
ydb_session_pool: SessionPool | AsyncSessionPool | None = None,
8080
root_certificates_path: str | None = None,
8181
root_certificates: str | None = None,
82+
driver_config_kwargs: dict | None = None,
8283
**kwargs: dict,
8384
) -> None:
8485
protocol = protocol if protocol else "grpc"
@@ -89,6 +90,8 @@ def __init__(
8990

9091
self.connection_kwargs: dict = kwargs
9192

93+
driver_config_kwargs = driver_config_kwargs or {}
94+
9295
self._shared_session_pool: bool = False
9396

9497
self._tx_context: TxContext | AsyncTxContext | None = None
@@ -113,6 +116,7 @@ def __init__(
113116
credentials=self.credentials,
114117
query_client_settings=self._get_client_settings(),
115118
root_certificates=root_certificates,
119+
**driver_config_kwargs,
116120
)
117121
self._driver = self._driver_cls(driver_config)
118122
self._session_pool = self._pool_cls(self._driver, size=5)
@@ -197,6 +201,7 @@ def __init__(
197201
ydb_session_pool: SessionPool | AsyncSessionPool | None = None,
198202
root_certificates_path: str | None = None,
199203
root_certificates: str | None = None,
204+
driver_config_kwargs: dict | None = None,
200205
**kwargs: dict,
201206
) -> None:
202207
super().__init__(
@@ -209,6 +214,7 @@ def __init__(
209214
ydb_session_pool=ydb_session_pool,
210215
root_certificates_path=root_certificates_path,
211216
root_certificates=root_certificates,
217+
driver_config_kwargs=driver_config_kwargs,
212218
**kwargs,
213219
)
214220
self._current_cursor: Cursor | None = None
@@ -390,6 +396,7 @@ def __init__(
390396
ydb_session_pool: SessionPool | AsyncSessionPool | None = None,
391397
root_certificates_path: str | None = None,
392398
root_certificates: str | None = None,
399+
driver_config_kwargs: dict | None = None,
393400
**kwargs: dict,
394401
) -> None:
395402
super().__init__(
@@ -402,6 +409,7 @@ def __init__(
402409
ydb_session_pool=ydb_session_pool,
403410
root_certificates_path=root_certificates_path,
404411
root_certificates=root_certificates,
412+
driver_config_kwargs=driver_config_kwargs,
405413
**kwargs,
406414
)
407415
self._current_cursor: AsyncCursor | None = None

0 commit comments

Comments
 (0)