Skip to content

Commit 68be308

Browse files
SNOW-2671717: Async connect context manager support - approach 1 - idempotent __aenter__ through checking if conn closed (#2614)
1 parent b151e77 commit 68be308

File tree

5 files changed

+335
-17
lines changed

5 files changed

+335
-17
lines changed
Lines changed: 151 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
from functools import wraps
4+
from typing import Any, Coroutine, Generator, Protocol, TypeVar, runtime_checkable
5+
36
from ._connection import SnowflakeConnection
47
from ._cursor import DictCursor, SnowflakeCursor
58

@@ -9,8 +12,152 @@
912
DictCursor,
1013
]
1114

15+
# ============================================================================
16+
# DESIGN NOTES:
17+
#
18+
# Pattern similar to aiohttp.ClientSession.request() which similarly returns
19+
# an object that can be both awaited and used as an async context manager.
20+
#
21+
# The async connect function uses a wrapper to support both:
22+
# 1. Direct awaiting: conn = await connect(...)
23+
# 2. Async context manager: async with connect(...) as conn:
24+
#
25+
# connect: A function decorated with @wraps(SnowflakeConnection.__init__) that
26+
# preserves metadata for IDE support, type checking, and introspection.
27+
# Returns a _AsyncConnectContextManager instance when called.
28+
#
29+
# _AsyncConnectContextManager: Implements __await__ and __aenter__/__aexit__
30+
# to support both patterns on the same awaitable.
31+
#
32+
# The @wraps decorator ensures that connect() has the same signature and
33+
# documentation as SnowflakeConnection.__init__, making it behave identically
34+
# to the sync snowflake.connector.connect function from an introspection POV.
35+
#
36+
# Metadata preservation is critical for IDE autocomplete, static type checkers,
37+
# and documentation generation to work correctly on the async connect function.
38+
# ============================================================================
39+
40+
41+
T = TypeVar("T")
42+
43+
44+
@runtime_checkable
45+
class HybridCoroutineContextManager(Protocol[T]):
46+
"""Protocol for a hybrid coroutine that is also an async context manager.
47+
48+
Combines the full coroutine protocol (PEP 492) with async context manager
49+
protocol (PEP 343/492), allowing code that expects either interface to work
50+
seamlessly with instances of this protocol.
51+
52+
This is used when external code needs to manage the coroutine lifecycle
53+
(e.g., timeout handlers, async schedulers) or use it as a context manager.
54+
"""
55+
56+
# Full Coroutine Protocol (PEP 492)
57+
def send(self, __arg: Any) -> Any:
58+
"""Send a value into the coroutine."""
59+
...
60+
61+
def throw(
62+
self,
63+
__typ: type[BaseException],
64+
__val: BaseException | None = None,
65+
__tb: Any = None,
66+
) -> Any:
67+
"""Throw an exception into the coroutine."""
68+
...
69+
70+
def close(self) -> None:
71+
"""Close the coroutine."""
72+
...
73+
74+
def __await__(self) -> Generator[Any, None, T]:
75+
"""Return awaitable generator."""
76+
...
77+
78+
def __iter__(self) -> Generator[Any, None, T]:
79+
"""Iterate over the coroutine."""
80+
...
81+
82+
# Async Context Manager Protocol (PEP 343)
83+
async def __aenter__(self) -> T:
84+
"""Async context manager entry."""
85+
...
86+
87+
async def __aexit__(
88+
self,
89+
__exc_type: type[BaseException] | None,
90+
__exc_val: BaseException | None,
91+
__exc_tb: Any,
92+
) -> bool | None:
93+
"""Async context manager exit."""
94+
...
95+
96+
97+
class _AsyncConnectContextManager(HybridCoroutineContextManager[SnowflakeConnection]):
98+
"""Hybrid wrapper that enables both awaiting and async context manager usage.
99+
100+
Allows both patterns:
101+
- conn = await connect(...)
102+
- async with connect(...) as conn:
103+
104+
Implements the full coroutine protocol for maximum compatibility.
105+
Satisfies the HybridCoroutineContextManager protocol.
106+
"""
107+
108+
__slots__ = ("_coro", "_conn")
109+
110+
def __init__(self, coro: Coroutine[Any, Any, SnowflakeConnection]) -> None:
111+
self._coro = coro
112+
self._conn: SnowflakeConnection | None = None
113+
114+
def send(self, arg: Any) -> Any:
115+
"""Send a value into the wrapped coroutine."""
116+
return self._coro.send(arg)
117+
118+
def throw(self, *args: Any, **kwargs: Any) -> Any:
119+
"""Throw an exception into the wrapped coroutine."""
120+
return self._coro.throw(*args, **kwargs)
121+
122+
def close(self) -> None:
123+
"""Close the wrapped coroutine."""
124+
return self._coro.close()
125+
126+
def __await__(self) -> Generator[Any, None, SnowflakeConnection]:
127+
"""Enable await connect(...)"""
128+
return self._coro.__await__()
129+
130+
def __iter__(self) -> Generator[Any, None, SnowflakeConnection]:
131+
"""Make the wrapper iterable like a coroutine."""
132+
return self.__await__()
133+
134+
# This approach requires idempotent __aenter__ of SnowflakeConnection class - so check if connected and do not repeat connecting
135+
async def __aenter__(self) -> SnowflakeConnection:
136+
"""Enable async with connect(...) as conn:"""
137+
self._conn = await self._coro
138+
return await self._conn.__aenter__()
139+
140+
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
141+
"""Exit async context manager."""
142+
if self._conn is not None:
143+
return await self._conn.__aexit__(exc_type, exc, tb)
144+
else:
145+
return None
146+
147+
148+
@wraps(SnowflakeConnection.__init__)
149+
def connect(**kwargs: Any) -> HybridCoroutineContextManager[SnowflakeConnection]:
150+
"""Create and connect to a Snowflake connection asynchronously.
151+
152+
Returns an awaitable that can also be used as an async context manager.
153+
Supports both patterns:
154+
- conn = await connect(...)
155+
- async with connect(...) as conn:
156+
"""
157+
158+
async def _connect_coro() -> SnowflakeConnection:
159+
conn = SnowflakeConnection(**kwargs)
160+
await conn.connect()
161+
return conn
12162

13-
async def connect(**kwargs) -> SnowflakeConnection:
14-
conn = SnowflakeConnection(**kwargs)
15-
await conn.connect()
16-
return conn
163+
return _AsyncConnectContextManager(_connect_coro())

src/snowflake/connector/aio/_connection.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,21 @@ def __init__(
123123
connections_file_path: pathlib.Path | None = None,
124124
**kwargs,
125125
) -> None:
126+
"""Create a new SnowflakeConnection.
127+
128+
Connections can be loaded from the TOML file located at
129+
snowflake.connector.constants.CONNECTIONS_FILE.
130+
131+
When connection_name is supplied we will first load that connection
132+
and then override any other values supplied.
133+
134+
When no arguments are given (other than connection_file_path) the
135+
default connection will be loaded first. Note that no overwriting is
136+
supported in this case.
137+
138+
If overwriting values from the default connection is desirable, supply
139+
the name explicitly.
140+
"""
126141
# note we don't call super here because asyncio can not/is not recommended
127142
# to perform async operation in the __init__ while in the sync connection we
128143
# perform connect
@@ -173,7 +188,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
173188

174189
async def __aenter__(self) -> SnowflakeConnection:
175190
"""Context manager."""
176-
await self.connect()
191+
# Idempotent __aenter__ - required to be able to use both:
192+
# - with snowflake.connector.aio.SnowflakeConnection(**k)
193+
# - with snowflake.connector.aio.connect(**k)
194+
if self.is_closed():
195+
await self.connect()
177196
return self
178197

179198
async def __aexit__(

test/integ/aio_it/conftest.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
get_db_parameters,
1010
is_public_testaccount,
1111
)
12-
from typing import AsyncContextManager, AsyncGenerator, Callable
12+
from typing import Any, AsyncContextManager, AsyncGenerator, Callable
1313

1414
import pytest
1515

1616
from snowflake.connector.aio import SnowflakeConnection
17+
from snowflake.connector.aio import connect as async_connect
1718
from snowflake.connector.aio._telemetry import TelemetryClient
1819
from snowflake.connector.connection import DefaultConverterClass
1920
from snowflake.connector.telemetry import TelemetryData
@@ -70,13 +71,7 @@ def capture_sf_telemetry_async() -> TelemetryCaptureFixtureAsync:
7071
return TelemetryCaptureFixtureAsync()
7172

7273

73-
async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
74-
"""Creates a connection using the parameters defined in parameters.py.
75-
76-
You can select from the different connections by supplying the appropiate
77-
connection_name parameter and then anything else supplied will overwrite the values
78-
from parameters.py.
79-
"""
74+
def fill_conn_kwargs_for_tests(connection_name: str, **kwargs) -> dict[str, Any]:
8075
ret = get_db_parameters(connection_name)
8176
ret.update(kwargs)
8277

@@ -95,9 +90,18 @@ async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnecti
9590
ret.pop("private_key", None)
9691
ret.pop("private_key_file", None)
9792

98-
connection = SnowflakeConnection(**ret)
99-
await connection.connect()
100-
return connection
93+
return ret
94+
95+
96+
async def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
97+
"""Creates a connection using the parameters defined in parameters.py.
98+
99+
You can select from the different connections by supplying the appropiate
100+
connection_name parameter and then anything else supplied will overwrite the values
101+
from parameters.py.
102+
"""
103+
ret = fill_conn_kwargs_for_tests(connection_name, **kwargs)
104+
return await async_connect(**ret)
101105

102106

103107
@asynccontextmanager

test/integ/aio_it/test_connection_async.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
CONNECTION_PARAMETERS_ADMIN = {}
4747
from snowflake.connector.aio.auth import AuthByOkta, AuthByPlugin
4848

49-
from .conftest import create_connection
49+
from .conftest import create_connection, fill_conn_kwargs_for_tests
5050

5151
try:
5252
from snowflake.connector.errorcode import ER_FAILED_PROCESSING_QMARK
@@ -1466,6 +1466,64 @@ async def test_platform_detection_timeout(conn_cnx):
14661466
assert cnx.platform_detection_timeout_seconds == 2.5
14671467

14681468

1469+
@pytest.mark.skipolddriver
1470+
async def test_conn_cnx_basic(conn_cnx):
1471+
"""Tests platform detection timeout.
1472+
1473+
Creates a connection with platform_detection_timeout parameter.
1474+
"""
1475+
async with conn_cnx() as conn:
1476+
async with conn.cursor() as cur:
1477+
result = await (await cur.execute("select 1")).fetchall()
1478+
assert len(result) == 1
1479+
assert result[0][0] == 1
1480+
1481+
1482+
@pytest.mark.skipolddriver
1483+
async def test_conn_assigned_method(conn_cnx):
1484+
conn = await snowflake.connector.aio.connect(
1485+
**fill_conn_kwargs_for_tests("default")
1486+
)
1487+
async with conn.cursor() as cur:
1488+
result = await (await cur.execute("select 1")).fetchall()
1489+
assert len(result) == 1
1490+
assert result[0][0] == 1
1491+
1492+
1493+
@pytest.mark.skipolddriver
1494+
async def test_conn_assigned_class(conn_cnx):
1495+
conn = snowflake.connector.aio.SnowflakeConnection(
1496+
**fill_conn_kwargs_for_tests("default")
1497+
)
1498+
await conn.connect()
1499+
async with conn.cursor() as cur:
1500+
result = await (await cur.execute("select 1")).fetchall()
1501+
assert len(result) == 1
1502+
assert result[0][0] == 1
1503+
1504+
1505+
@pytest.mark.skipolddriver
1506+
async def test_conn_with_method(conn_cnx):
1507+
async with snowflake.connector.aio.connect(
1508+
**fill_conn_kwargs_for_tests("default")
1509+
) as conn:
1510+
async with conn.cursor() as cur:
1511+
result = await (await cur.execute("select 1")).fetchall()
1512+
assert len(result) == 1
1513+
assert result[0][0] == 1
1514+
1515+
1516+
@pytest.mark.skipolddriver
1517+
async def test_conn_with_class(conn_cnx):
1518+
async with snowflake.connector.aio.SnowflakeConnection(
1519+
**fill_conn_kwargs_for_tests("default")
1520+
) as conn:
1521+
async with conn.cursor() as cur:
1522+
result = await (await cur.execute("select 1")).fetchall()
1523+
assert len(result) == 1
1524+
assert result[0][0] == 1
1525+
1526+
14691527
@pytest.mark.skipolddriver
14701528
async def test_platform_detection_zero_timeout(conn_cnx):
14711529
with (

0 commit comments

Comments
 (0)