@@ -2140,3 +2140,153 @@ async def __aexit__(self, *args):
21402140 # Verify the final speech_config is still None
21412141 assert config_arg .speech_config is None
21422142 assert isinstance (connection , GeminiLlmConnection )
2143+
2144+
2145+ # ---------------------------------------------------------------------------
2146+ # Tests for custom client injection (Issue #2560)
2147+ # ---------------------------------------------------------------------------
2148+
2149+
def test_custom_client_is_used_for_api_client():
  """When a custom client is provided, api_client returns it directly."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  llm = Gemini(model="gemini-1.5-flash", client=injected)

  # The property must hand back the injected instance, not a new Client.
  assert llm.api_client is injected
2158+
2159+
def test_custom_client_is_used_for_live_api_client():
  """When a custom client is provided, _live_api_client returns it directly."""
  from google.genai import Client

  supplied = mock.MagicMock(spec=Client)
  llm = Gemini(model="gemini-1.5-flash", client=supplied)

  # Live connections must reuse the injected client rather than build one.
  assert llm._live_api_client is supplied
2168+
2169+
def test_default_api_client_when_no_custom_client():
  """Without a custom client, api_client creates a default Client."""
  from google.genai import Client

  llm = Gemini(model="gemini-1.5-flash")
  default_client = llm.api_client

  # A real google.genai.Client is constructed lazily, never left unset.
  assert default_client is not None
  assert isinstance(default_client, Client)
2181+
2182+
def test_default_live_api_client_when_no_custom_client():
  """Without a custom client, _live_api_client creates a default Client."""
  from google.genai import Client

  llm = Gemini(model="gemini-1.5-flash")
  live_client = llm._live_api_client

  # The live-API path also falls back to a real default Client.
  assert live_client is not None
  assert isinstance(live_client, Client)
2192+
2193+
def test_custom_client_api_backend_vertexai():
  """_api_backend reflects the custom client's vertexai setting."""
  from google.genai import Client

  vertex_client = mock.MagicMock(spec=Client)
  vertex_client.vertexai = True
  llm = Gemini(model="gemini-1.5-flash", client=vertex_client)

  # Backend selection follows the injected client's flag, not the env.
  assert llm._api_backend == GoogleLLMVariant.VERTEX_AI
2203+
2204+
def test_custom_client_api_backend_gemini_api():
  """_api_backend reflects non-vertexai custom client."""
  from google.genai import Client

  api_client = mock.MagicMock(spec=Client)
  api_client.vertexai = False
  llm = Gemini(model="gemini-1.5-flash", client=api_client)

  # vertexai=False on the injected client selects the Gemini API backend.
  assert llm._api_backend == GoogleLLMVariant.GEMINI_API
2214+
2215+
@pytest.mark.asyncio
async def test_custom_client_used_for_generate_content():
  """Custom client is used when generate_content_async is called.

  Stubs the injected client's async `generate_content` with an AsyncMock so
  the call is routed through the custom client, then verifies the response
  content round-trips and the stub was invoked exactly once.
  """
  from google.genai import Client

  custom_client = mock.MagicMock(spec=Client)
  custom_client.vertexai = False
  gemini = Gemini(model="gemini-1.5-flash", client=custom_client)

  generate_content_response = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=Content(
                  role="model",
                  parts=[Part.from_text(text="Hello from custom client")],
              ),
              finish_reason=types.FinishReason.STOP,
          )
      ]
  )

  # AsyncMock produces a fresh awaitable per call. The previous hand-rolled
  # coroutine assigned to return_value could be awaited only once, so any
  # retry (or a second call) would fail with "cannot reuse already awaited
  # coroutine" instead of exercising the code under test.
  custom_client.aio.models.generate_content = mock.AsyncMock(
      return_value=generate_content_response
  )

  llm_request = LlmRequest(
      model="gemini-1.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=types.GenerateContentConfig(
          system_instruction="You are a helpful assistant",
      ),
  )

  responses = [
      resp
      async for resp in gemini.generate_content_async(llm_request, stream=False)
  ]

  assert len(responses) == 1
  assert responses[0].content.parts[0].text == "Hello from custom client"
  custom_client.aio.models.generate_content.assert_called_once()
2258+
2259+
@pytest.mark.asyncio
async def test_custom_client_used_for_live_connect():
  """Custom client is used for live API streaming connections."""
  from google.genai import Client

  injected = mock.MagicMock(spec=Client)
  injected.vertexai = False
  llm = Gemini(model="gemini-1.5-flash", client=injected)

  live_session = mock.AsyncMock()

  class _FakeLiveContext:
    """Async context manager standing in for the live-connect handle."""

    async def __aenter__(self):
      return live_session

    async def __aexit__(self, *exc_info):
      pass

  injected.aio.live.connect.return_value = _FakeLiveContext()

  request = LlmRequest(
      model="gemini-1.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
      config=types.GenerateContentConfig(
          system_instruction="You are a helpful assistant",
      ),
  )
  request.live_connect_config = types.LiveConnectConfig()

  # Connecting must go through the injected client's live endpoint and wrap
  # the session in the project's connection type.
  async with llm.connect(request) as connection:
    injected.aio.live.connect.assert_called_once()
    assert isinstance(connection, GeminiLlmConnection)