|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from functools import wraps |
| 4 | +from typing import Any, Coroutine, Generator, Protocol, TypeVar, runtime_checkable |
| 5 | + |
3 | 6 | from ._connection import SnowflakeConnection |
4 | 7 | from ._cursor import DictCursor, SnowflakeCursor |
5 | 8 |
|
|
9 | 12 | DictCursor, |
10 | 13 | ] |
11 | 14 |
|
| 15 | +# ============================================================================ |
| 16 | +# DESIGN NOTES: |
| 17 | +# |
| 18 | +# Pattern similar to aiohttp.ClientSession.request() which similarly returns |
| 19 | +# an object that can be both awaited and used as an async context manager. |
| 20 | +# |
| 21 | +# The async connect function uses a wrapper to support both: |
| 22 | +# 1. Direct awaiting: conn = await connect(...) |
| 23 | +# 2. Async context manager: async with connect(...) as conn: |
| 24 | +# |
| 25 | +# connect: A function decorated with @wraps(SnowflakeConnection.__init__) that |
| 26 | +# preserves metadata for IDE support, type checking, and introspection. |
| 27 | +# Returns a _AsyncConnectContextManager instance when called. |
| 28 | +# |
| 29 | +# _AsyncConnectContextManager: Implements __await__ and __aenter__/__aexit__ |
| 30 | +# to support both patterns on the same awaitable. |
| 31 | +# |
| 32 | +# The @wraps decorator ensures that connect() has the same signature and |
| 33 | +# documentation as SnowflakeConnection.__init__, making it behave identically |
| 34 | +# to the sync snowflake.connector.connect function from an introspection POV. |
| 35 | +# |
| 36 | +# Metadata preservation is critical for IDE autocomplete, static type checkers, |
| 37 | +# and documentation generation to work correctly on the async connect function. |
| 38 | +# ============================================================================ |
| 39 | + |
| 40 | + |
| 41 | +T = TypeVar("T") |
| 42 | + |
| 43 | + |
| 44 | +@runtime_checkable |
| 45 | +class HybridCoroutineContextManager(Protocol[T]): |
| 46 | + """Protocol for a hybrid coroutine that is also an async context manager. |
| 47 | +
|
| 48 | + Combines the full coroutine protocol (PEP 492) with async context manager |
| 49 | + protocol (PEP 343/492), allowing code that expects either interface to work |
| 50 | + seamlessly with instances of this protocol. |
| 51 | +
|
| 52 | + This is used when external code needs to manage the coroutine lifecycle |
| 53 | + (e.g., timeout handlers, async schedulers) or use it as a context manager. |
| 54 | + """ |
| 55 | + |
| 56 | + # Full Coroutine Protocol (PEP 492) |
| 57 | + def send(self, __arg: Any) -> Any: |
| 58 | + """Send a value into the coroutine.""" |
| 59 | + ... |
| 60 | + |
| 61 | + def throw( |
| 62 | + self, |
| 63 | + __typ: type[BaseException], |
| 64 | + __val: BaseException | None = None, |
| 65 | + __tb: Any = None, |
| 66 | + ) -> Any: |
| 67 | + """Throw an exception into the coroutine.""" |
| 68 | + ... |
| 69 | + |
| 70 | + def close(self) -> None: |
| 71 | + """Close the coroutine.""" |
| 72 | + ... |
| 73 | + |
| 74 | + def __await__(self) -> Generator[Any, None, T]: |
| 75 | + """Return awaitable generator.""" |
| 76 | + ... |
| 77 | + |
| 78 | + def __iter__(self) -> Generator[Any, None, T]: |
| 79 | + """Iterate over the coroutine.""" |
| 80 | + ... |
| 81 | + |
| 82 | + # Async Context Manager Protocol (PEP 343) |
| 83 | + async def __aenter__(self) -> T: |
| 84 | + """Async context manager entry.""" |
| 85 | + ... |
| 86 | + |
| 87 | + async def __aexit__( |
| 88 | + self, |
| 89 | + __exc_type: type[BaseException] | None, |
| 90 | + __exc_val: BaseException | None, |
| 91 | + __exc_tb: Any, |
| 92 | + ) -> bool | None: |
| 93 | + """Async context manager exit.""" |
| 94 | + ... |
| 95 | + |
| 96 | + |
| 97 | +class _AsyncConnectContextManager(HybridCoroutineContextManager[SnowflakeConnection]): |
| 98 | + """Hybrid wrapper that enables both awaiting and async context manager usage. |
| 99 | +
|
| 100 | + Allows both patterns: |
| 101 | + - conn = await connect(...) |
| 102 | + - async with connect(...) as conn: |
| 103 | +
|
| 104 | + Implements the full coroutine protocol for maximum compatibility. |
| 105 | + Satisfies the HybridCoroutineContextManager protocol. |
| 106 | + """ |
| 107 | + |
| 108 | + __slots__ = ("_coro", "_conn") |
| 109 | + |
| 110 | + def __init__(self, coro: Coroutine[Any, Any, SnowflakeConnection]) -> None: |
| 111 | + self._coro = coro |
| 112 | + self._conn: SnowflakeConnection | None = None |
| 113 | + |
| 114 | + def send(self, arg: Any) -> Any: |
| 115 | + """Send a value into the wrapped coroutine.""" |
| 116 | + return self._coro.send(arg) |
| 117 | + |
| 118 | + def throw(self, *args: Any, **kwargs: Any) -> Any: |
| 119 | + """Throw an exception into the wrapped coroutine.""" |
| 120 | + return self._coro.throw(*args, **kwargs) |
| 121 | + |
| 122 | + def close(self) -> None: |
| 123 | + """Close the wrapped coroutine.""" |
| 124 | + return self._coro.close() |
| 125 | + |
| 126 | + def __await__(self) -> Generator[Any, None, SnowflakeConnection]: |
| 127 | + """Enable await connect(...)""" |
| 128 | + return self._coro.__await__() |
| 129 | + |
| 130 | + def __iter__(self) -> Generator[Any, None, SnowflakeConnection]: |
| 131 | + """Make the wrapper iterable like a coroutine.""" |
| 132 | + return self.__await__() |
| 133 | + |
| 134 | + # This approach requires idempotent __aenter__ of SnowflakeConnection class - so check if connected and do not repeat connecting |
| 135 | + async def __aenter__(self) -> SnowflakeConnection: |
| 136 | + """Enable async with connect(...) as conn:""" |
| 137 | + self._conn = await self._coro |
| 138 | + return await self._conn.__aenter__() |
| 139 | + |
| 140 | + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: |
| 141 | + """Exit async context manager.""" |
| 142 | + if self._conn is not None: |
| 143 | + return await self._conn.__aexit__(exc_type, exc, tb) |
| 144 | + else: |
| 145 | + return None |
| 146 | + |
| 147 | + |
| 148 | +@wraps(SnowflakeConnection.__init__) |
| 149 | +def connect(**kwargs: Any) -> HybridCoroutineContextManager[SnowflakeConnection]: |
| 150 | + """Create and connect to a Snowflake connection asynchronously. |
| 151 | +
|
| 152 | + Returns an awaitable that can also be used as an async context manager. |
| 153 | + Supports both patterns: |
| 154 | + - conn = await connect(...) |
| 155 | + - async with connect(...) as conn: |
| 156 | + """ |
| 157 | + |
| 158 | + async def _connect_coro() -> SnowflakeConnection: |
| 159 | + conn = SnowflakeConnection(**kwargs) |
| 160 | + await conn.connect() |
| 161 | + return conn |
12 | 162 |
|
13 | | -async def connect(**kwargs) -> SnowflakeConnection: |
14 | | - conn = SnowflakeConnection(**kwargs) |
15 | | - await conn.connect() |
16 | | - return conn |
| 163 | + return _AsyncConnectContextManager(_connect_coro()) |
0 commit comments