Skip to content

Commit 8075e60

Browse files
sararobcopybara-github
authored andcommitted
feat: GenAI SDK client - add context management to AsyncClient
PiperOrigin-RevId: 825009766
1 parent a52da0b commit 8075e60

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

tests/unit/vertexai/genai/test_genai_client.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
import importlib
1919
import pytest
20+
from unittest import mock
2021

2122
from google.cloud import aiplatform
2223
import vertexai
24+
from vertexai._genai import client as vertexai_client
2325
from google.cloud.aiplatform import initializer as aiplatform_initializer
2426

2527

@@ -66,3 +68,28 @@ def test_live_client(self):
6668
def test_types(self):
6769
assert vertexai.types is not None
6870
assert vertexai.types.LLMMetric is not None
71+
72+
@pytest.mark.asyncio
73+
@pytest.mark.usefixtures("google_auth_mock")
74+
async def test_async_content_manager(self):
75+
with mock.patch.object(
76+
vertexai_client.AsyncClient, "aclose", autospec=True
77+
) as mock_aclose:
78+
async with vertexai.Client(
79+
project=_TEST_PROJECT, location=_TEST_LOCATION
80+
).aio as async_client:
81+
assert isinstance(async_client, vertexai_client.AsyncClient)
82+
83+
mock_aclose.assert_called_once()
84+
85+
@pytest.mark.asyncio
86+
@pytest.mark.usefixtures("google_auth_mock")
87+
async def test_call_aclose_async_client(self):
88+
with mock.patch.object(
89+
vertexai_client.AsyncClient, "aclose", autospec=True
90+
) as mock_aclose:
91+
async_client = vertexai.Client(
92+
project=_TEST_PROJECT, location=_TEST_LOCATION
93+
).aio
94+
await async_client.aclose()
95+
mock_aclose.assert_called_once()

vertexai/_genai/client.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414
#
1515

16+
import asyncio
1617
import importlib
1718
from typing import Optional, Union, Any
19+
from types import TracebackType
1820

1921
import google.auth
2022
from google.cloud.aiplatform import version as aip_version
@@ -48,7 +50,7 @@ def _add_tracking_headers(headers: dict[str, str]) -> None:
4850
class AsyncClient:
4951
"""Async Gen AI Client for the Vertex SDK."""
5052

51-
def __init__(self, api_client: genai_client.Client):
53+
def __init__(self, api_client: genai_client.BaseApiClient):
5254
self._api_client = api_client
5355
self._live = live.AsyncLive(self._api_client)
5456
self._evals = None
@@ -132,6 +134,40 @@ def datasets(self):
132134
)
133135
return self._datasets.AsyncDatasets(self._api_client)
134136

137+
async def aclose(self) -> None:
138+
"""Closes the async client explicitly.
139+
140+
Example usage:
141+
142+
from vertexai import Client
143+
144+
async_client = vertexai.Client(
145+
project='my-project-id', location='us-central1'
146+
).aio
147+
prompt_1 = await async_client.prompts.create(...)
148+
prompt_2 = await async_client.prompts.create(...)
149+
# Close the client to release resources.
150+
await async_client.aclose()
151+
"""
152+
await self._api_client.aclose()
153+
154+
async def __aenter__(self) -> "AsyncClient":
155+
return self
156+
157+
async def __aexit__(
158+
self,
159+
exc_type: Optional[Exception],
160+
exc_value: Optional[Exception],
161+
traceback: Optional[TracebackType],
162+
) -> None:
163+
await self.aclose()
164+
165+
def __del__(self) -> None:
166+
try:
167+
asyncio.get_running_loop().create_task(self.aclose())
168+
except Exception:
169+
pass
170+
135171

136172
class Client:
137173
"""Gen AI Client for the Vertex SDK.

0 commit comments

Comments
 (0)