Skip to content

Commit c615757

Browse files
GWeale and copybara-github
authored and committed
fix: Add support for injecting a custom google.genai.Client into Gemini models
This change introduces a new `client` parameter to the `Gemini` model's constructor. When provided, this preconfigured `google.genai.Client` instance is used for all API calls, offering fine-grained control over authentication, project, and location settings. Closes #2560. Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 874628604
1 parent 8c0bd20 commit c615757

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed

src/google/adk/models/google_llm.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,23 @@ class Gemini(BaseLlm):
8585
8686
Attributes:
8787
model: The name of the Gemini model.
88+
client: An optional preconfigured ``google.genai.Client`` instance.
89+
When provided, ADK uses this client for all API calls instead of
90+
creating one internally from environment variables or ADC. This
91+
allows fine-grained control over authentication, project, location,
92+
and other client-level settings — and enables running agents that
93+
target different Vertex AI regions within the same process.
94+
95+
Example::
96+
97+
from google import genai
98+
from google.adk.models import Gemini
99+
100+
client = genai.Client(
101+
vertexai=True, project="my-project", location="us-central1"
102+
)
103+
model = Gemini(model="gemini-2.5-flash", client=client)
104+
88105
use_interactions_api: Whether to use the interactions API for model
89106
invocation.
90107
"""
@@ -131,6 +148,35 @@ class Gemini(BaseLlm):
131148
```
132149
"""
133150

151+
def __init__(self, *, client: Optional[Client] = None, **kwargs: Any):
  """Creates a Gemini model wrapper.

  Args:
    client: Optional preconfigured ``google.genai.Client``. If supplied,
      ADK routes every Gemini API call (the Live API included) through
      this client rather than building one internally from environment
      variables or ADC.

      .. note::
        A supplied client is used verbatim for Live API connections; ADK
        does not touch its ``api_version``. Set the appropriate version
        yourself (``v1beta1`` for Vertex AI, ``v1alpha`` for the Gemini
        developer API).

      .. warning::
        ``google.genai.Client`` holds threading primitives that cannot
        be pickled. When deploying to Agent Engine (or anywhere the
        model is serialised), omit the custom client and let ADK build
        one from the environment instead.

    **kwargs: Passed through to the Pydantic ``BaseLlm`` constructor
      (``model``, ``base_url``, ``retry_options``, etc.).
  """
  super().__init__(**kwargs)
  # Bypass Pydantic's attribute machinery; runs only after validation in
  # super().__init__ has completed.
  object.__setattr__(self, '_client', client)
179+
134180
@classmethod
135181
@override
136182
def supported_models(cls) -> list[str]:
@@ -299,9 +345,16 @@ async def _generate_content_via_interactions(
299345
def api_client(self) -> Client:
300346
"""Provides the api client.
301347
348+
If a preconfigured ``client`` was passed to the constructor it is
349+
returned directly; otherwise a new ``Client`` is created using the
350+
default environment/ADC configuration.
351+
302352
Returns:
303353
The api client.
304354
"""
355+
if self._client is not None:
356+
return self._client
357+
305358
from google.genai import Client
306359

307360
return Client(
@@ -334,6 +387,9 @@ def _live_api_version(self) -> str:
334387

335388
@cached_property
336389
def _live_api_client(self) -> Client:
390+
if self._client is not None:
391+
return self._client
392+
337393
from google.genai import Client
338394

339395
return Client(

tests/unittests/models/test_google_llm.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,3 +2140,153 @@ async def __aexit__(self, *args):
21402140
# Verify the final speech_config is still None
21412141
assert config_arg.speech_config is None
21422142
assert isinstance(connection, GeminiLlmConnection)
2143+
2144+
2145+
# ---------------------------------------------------------------------------
2146+
# Tests for custom client injection (Issue #2560)
2147+
# ---------------------------------------------------------------------------
2148+
2149+
2150+
def test_custom_client_is_used_for_api_client():
  """api_client hands back the injected client untouched."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  model = Gemini(model="gemini-1.5-flash", client=injected)

  assert model.api_client is injected
2158+
2159+
2160+
def test_custom_client_is_used_for_live_api_client():
  """_live_api_client hands back the injected client untouched."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  model = Gemini(model="gemini-1.5-flash", client=injected)

  assert model._live_api_client is injected
2168+
2169+
2170+
def test_default_api_client_when_no_custom_client():
  """Without an injected client, api_client builds a real default Client."""
  from google.genai import Client

  model = Gemini(model="gemini-1.5-flash")
  created = model.api_client

  # A genuine google.genai.Client (not a mock, not None) must have been
  # constructed from the environment defaults.
  assert created is not None
  assert isinstance(created, Client)
2181+
2182+
2183+
def test_default_live_api_client_when_no_custom_client():
  """Without an injected client, _live_api_client builds a default Client."""
  from google.genai import Client

  model = Gemini(model="gemini-1.5-flash")
  created = model._live_api_client

  assert created is not None
  assert isinstance(created, Client)
2192+
2193+
2194+
def test_custom_client_api_backend_vertexai():
  """_api_backend mirrors a custom client configured for Vertex AI."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  injected.vertexai = True

  model = Gemini(model="gemini-1.5-flash", client=injected)

  assert model._api_backend == GoogleLLMVariant.VERTEX_AI
2203+
2204+
2205+
def test_custom_client_api_backend_gemini_api():
  """_api_backend mirrors a custom client that is not using Vertex AI."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  injected.vertexai = False

  model = Gemini(model="gemini-1.5-flash", client=injected)

  assert model._api_backend == GoogleLLMVariant.GEMINI_API
2214+
2215+
2216+
@pytest.mark.asyncio
async def test_custom_client_used_for_generate_content():
  """generate_content_async routes the request through the injected client."""
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  custom_client.vertexai = False
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  generate_content_response = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=Content(
                  role="model",
                  parts=[Part.from_text(text="Hello from custom client")],
              ),
              finish_reason=types.FinishReason.STOP,
          )
      ]
  )

  # Use AsyncMock rather than assigning a raw coroutine object as
  # return_value: a coroutine is single-await (a second call would fail)
  # and emits a "never awaited" RuntimeWarning if the call path is not
  # reached, whereas AsyncMock produces a fresh awaitable per call.
  custom_client.aio.models.generate_content = mock.AsyncMock(
      return_value=generate_content_response
  )

  llm_request = LlmRequest(
      model="gemini-1.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=types.GenerateContentConfig(
          system_instruction="You are a helpful assistant",
      ),
  )

  responses = [
      resp
      async for resp in gemini.generate_content_async(llm_request, stream=False)
  ]

  assert len(responses) == 1
  assert responses[0].content.parts[0].text == "Hello from custom client"
  custom_client.aio.models.generate_content.assert_called_once()
2258+
2259+
2260+
@pytest.mark.asyncio
async def test_custom_client_used_for_live_connect():
  """connect() opens the live streaming session via the injected client."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  injected.vertexai = False
  model = Gemini(model="gemini-1.5-flash", client=injected)

  live_session = mock.AsyncMock()

  class _FakeLiveConnect:
    """Async context manager standing in for aio.live.connect()."""

    async def __aenter__(self):
      return live_session

    async def __aexit__(self, *exc_info):
      pass

  injected.aio.live.connect.return_value = _FakeLiveConnect()

  request = LlmRequest(
      model="gemini-1.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=types.GenerateContentConfig(
          system_instruction="You are a helpful assistant",
      ),
  )
  request.live_connect_config = types.LiveConnectConfig()

  async with model.connect(request) as connection:
    injected.aio.live.connect.assert_called_once()
    assert isinstance(connection, GeminiLlmConnection)

0 commit comments

Comments
 (0)