Skip to content

Commit 54acfae

Browse files
MarkDaoustcopybara-github
authored andcommitted
feat: [Python] add RegisterFiles so gcs files can be used with genai.
PiperOrigin-RevId: 839396637
1 parent 07c74dd commit 54acfae

File tree

5 files changed

+683
-18
lines changed

5 files changed

+683
-18
lines changed

google/genai/_api_client.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -984,12 +984,7 @@ def _access_token(self) -> str:
984984
self.project = project
985985

986986
if self._credentials:
987-
if self._credentials.expired or not self._credentials.token:
988-
# Only refresh when it needs to. Default expiration is 3600 seconds.
989-
refresh_auth(self._credentials)
990-
if not self._credentials.token:
991-
raise RuntimeError('Could not resolve API token from the environment')
992-
return self._credentials.token # type: ignore[no-any-return]
987+
return get_token_from_credentials(self, self._credentials) # type: ignore[no-any-return]
993988
else:
994989
raise RuntimeError('Could not resolve API token from the environment')
995990

@@ -1034,18 +1029,10 @@ async def _async_access_token(self) -> Union[str, Any]:
10341029
self.project = project
10351030

10361031
if self._credentials:
1037-
if self._credentials.expired or not self._credentials.token:
1038-
# Only refresh when it needs to. Default expiration is 3600 seconds.
1039-
async_auth_lock = await self._get_async_auth_lock()
1040-
async with async_auth_lock:
1041-
if self._credentials.expired or not self._credentials.token:
1042-
# Double check that the credentials expired before refreshing.
1043-
await asyncio.to_thread(refresh_auth, self._credentials)
1044-
1045-
if not self._credentials.token:
1046-
raise RuntimeError('Could not resolve API token from the environment')
1047-
1048-
return self._credentials.token
1032+
return await async_get_token_from_credentials(
1033+
self,
1034+
self._credentials
1035+
) # type: ignore[no-any-return]
10491036
else:
10501037
raise RuntimeError('Could not resolve API token from the environment')
10511038

@@ -1925,3 +1912,35 @@ def __del__(self) -> None:
19251912
asyncio.get_running_loop().create_task(self.aclose())
19261913
except Exception: # pylint: disable=broad-except
19271914
pass
1915+
1916+
def get_token_from_credentials(
1917+
client: 'BaseApiClient',
1918+
credentials: google.auth.credentials.Credentials
1919+
) -> str:
1920+
"""Refreshes the authentication token for the given credentials."""
1921+
if credentials.expired or not credentials.token:
1922+
# Only refresh when it needs to. Default expiration is 3600 seconds.
1923+
refresh_auth(credentials)
1924+
if not credentials.token:
1925+
raise RuntimeError('Could not resolve API token from the environment')
1926+
return credentials.token # type: ignore[no-any-return]
1927+
1928+
async def async_get_token_from_credentials(
1929+
client: 'BaseApiClient',
1930+
credentials: google.auth.credentials.Credentials
1931+
) -> str:
1932+
"""Refreshes the authentication token for the given credentials."""
1933+
if credentials.expired or not credentials.token:
1934+
# Only refresh when it needs to. Default expiration is 3600 seconds.
1935+
async_auth_lock = await client._get_async_auth_lock()
1936+
async with async_auth_lock:
1937+
if credentials.expired or not credentials.token:
1938+
# Double check that the credentials expired before refreshing.
1939+
await asyncio.to_thread(refresh_auth, credentials)
1940+
1941+
if not credentials.token:
1942+
raise RuntimeError('Could not resolve API token from the environment')
1943+
1944+
return credentials.token # type: ignore[no-any-return]
1945+
1946+

google/genai/files.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from typing import Any, Optional, Union
2323
from urllib.parse import urlencode
2424

25+
import google.auth
26+
27+
from . import _api_client
2528
from . import _api_module
2629
from . import _common
2730
from . import _extra_utils
@@ -149,6 +152,33 @@ def _ListFilesResponse_from_mldev(
149152
return to_object
150153

151154

155+
def _RegisterFilesParameters_to_mldev(
156+
from_object: Union[dict[str, Any], object],
157+
parent_object: Optional[dict[str, Any]] = None,
158+
) -> dict[str, Any]:
159+
to_object: dict[str, Any] = {}
160+
if getv(from_object, ['uris']) is not None:
161+
setv(to_object, ['uris'], getv(from_object, ['uris']))
162+
163+
return to_object
164+
165+
166+
def _RegisterFilesResponse_from_mldev(
167+
from_object: Union[dict[str, Any], object],
168+
parent_object: Optional[dict[str, Any]] = None,
169+
) -> dict[str, Any]:
170+
to_object: dict[str, Any] = {}
171+
if getv(from_object, ['sdkHttpResponse']) is not None:
172+
setv(
173+
to_object, ['sdk_http_response'], getv(from_object, ['sdkHttpResponse'])
174+
)
175+
176+
if getv(from_object, ['files']) is not None:
177+
setv(to_object, ['files'], [item for item in getv(from_object, ['files'])])
178+
179+
return to_object
180+
181+
152182
class Files(_api_module.BaseModule):
153183

154184
def _list(
@@ -402,6 +432,69 @@ def delete(
402432
self._api_client._verify_response(return_value)
403433
return return_value
404434

435+
def _register_files(
436+
self,
437+
*,
438+
uris: list[str],
439+
config: Optional[types.RegisterFilesConfigOrDict] = None,
440+
) -> types.RegisterFilesResponse:
441+
parameter_model = types._RegisterFilesParameters(
442+
uris=uris,
443+
config=config,
444+
)
445+
446+
request_url_dict: Optional[dict[str, str]]
447+
if self._api_client.vertexai:
448+
raise ValueError(
449+
'This method is only supported in the Gemini Developer client.'
450+
)
451+
else:
452+
request_dict = _RegisterFilesParameters_to_mldev(parameter_model)
453+
request_url_dict = request_dict.get('_url')
454+
if request_url_dict:
455+
path = 'files:register'.format_map(request_url_dict)
456+
else:
457+
path = 'files:register'
458+
459+
query_params = request_dict.get('_query')
460+
if query_params:
461+
path = f'{path}?{urlencode(query_params)}'
462+
# TODO: remove the hack that pops config.
463+
request_dict.pop('config', None)
464+
465+
http_options: Optional[types.HttpOptions] = None
466+
if (
467+
parameter_model.config is not None
468+
and parameter_model.config.http_options is not None
469+
):
470+
http_options = parameter_model.config.http_options
471+
472+
request_dict = _common.convert_to_dict(request_dict)
473+
request_dict = _common.encode_unserializable_types(request_dict)
474+
475+
response = self._api_client.request(
476+
'post', path, request_dict, http_options
477+
)
478+
479+
if config is not None and getattr(
480+
config, 'should_return_http_response', None
481+
):
482+
return_value = types.RegisterFilesResponse(sdk_http_response=response)
483+
self._api_client._verify_response(return_value)
484+
return return_value
485+
486+
response_dict = {} if not response.body else json.loads(response.body)
487+
488+
if not self._api_client.vertexai:
489+
response_dict = _RegisterFilesResponse_from_mldev(response_dict)
490+
491+
return_value = types.RegisterFilesResponse._from_response(
492+
response=response_dict, kwargs=parameter_model.model_dump()
493+
)
494+
495+
self._api_client._verify_response(return_value)
496+
return return_value
497+
405498
def upload(
406499
self,
407500
*,
@@ -559,6 +652,39 @@ def download(
559652

560653
return data
561654

655+
def register_files(
656+
self,
657+
*,
658+
auth: google.auth.credentials.Credentials,
659+
uris: list[str],
660+
config: Optional[types.RegisterFilesConfigOrDict] = None,
661+
) -> types.RegisterFilesResponse:
662+
"""Registers gcs files with the file service."""
663+
if not isinstance(auth, google.auth.credentials.Credentials):
664+
raise ValueError(
665+
'auth must be a google.auth.credentials.Credentials object.'
666+
)
667+
if config is None:
668+
config = types.RegisterFilesConfig()
669+
else:
670+
config = types.RegisterFilesConfig.model_validate(config)
671+
config = config.model_copy(deep=True)
672+
673+
http_options = config.http_options or types.HttpOptions()
674+
headers = http_options.headers or {}
675+
headers = {k.lower(): v for k, v in headers.items()}
676+
677+
token = _api_client.get_token_from_credentials(self._api_client, auth)
678+
headers['authorization'] = f'Bearer {token}'
679+
680+
if auth.quota_project_id:
681+
headers['x-goog-user-project'] = auth.quota_project_id
682+
683+
http_options.headers = headers
684+
config.http_options = http_options
685+
686+
return self._register_files(uris=uris, config=config)
687+
562688
def list(
563689
self, *, config: Optional[types.ListFilesConfigOrDict] = None
564690
) -> Pager[types.File]:
@@ -845,6 +971,69 @@ async def delete(
845971
self._api_client._verify_response(return_value)
846972
return return_value
847973

974+
async def _register_files(
975+
self,
976+
*,
977+
uris: list[str],
978+
config: Optional[types.RegisterFilesConfigOrDict] = None,
979+
) -> types.RegisterFilesResponse:
980+
parameter_model = types._RegisterFilesParameters(
981+
uris=uris,
982+
config=config,
983+
)
984+
985+
request_url_dict: Optional[dict[str, str]]
986+
if self._api_client.vertexai:
987+
raise ValueError(
988+
'This method is only supported in the Gemini Developer client.'
989+
)
990+
else:
991+
request_dict = _RegisterFilesParameters_to_mldev(parameter_model)
992+
request_url_dict = request_dict.get('_url')
993+
if request_url_dict:
994+
path = 'files:register'.format_map(request_url_dict)
995+
else:
996+
path = 'files:register'
997+
998+
query_params = request_dict.get('_query')
999+
if query_params:
1000+
path = f'{path}?{urlencode(query_params)}'
1001+
# TODO: remove the hack that pops config.
1002+
request_dict.pop('config', None)
1003+
1004+
http_options: Optional[types.HttpOptions] = None
1005+
if (
1006+
parameter_model.config is not None
1007+
and parameter_model.config.http_options is not None
1008+
):
1009+
http_options = parameter_model.config.http_options
1010+
1011+
request_dict = _common.convert_to_dict(request_dict)
1012+
request_dict = _common.encode_unserializable_types(request_dict)
1013+
1014+
response = await self._api_client.async_request(
1015+
'post', path, request_dict, http_options
1016+
)
1017+
1018+
if config is not None and getattr(
1019+
config, 'should_return_http_response', None
1020+
):
1021+
return_value = types.RegisterFilesResponse(sdk_http_response=response)
1022+
self._api_client._verify_response(return_value)
1023+
return return_value
1024+
1025+
response_dict = {} if not response.body else json.loads(response.body)
1026+
1027+
if not self._api_client.vertexai:
1028+
response_dict = _RegisterFilesResponse_from_mldev(response_dict)
1029+
1030+
return_value = types.RegisterFilesResponse._from_response(
1031+
response=response_dict, kwargs=parameter_model.model_dump()
1032+
)
1033+
1034+
self._api_client._verify_response(return_value)
1035+
return return_value
1036+
8481037
async def upload(
8491038
self,
8501039
*,
@@ -992,6 +1181,41 @@ async def download(
9921181

9931182
return data
9941183

1184+
async def register_files(
1185+
self,
1186+
*,
1187+
auth: google.auth.credentials.Credentials,
1188+
uris: list[str],
1189+
config: Optional[types.RegisterFilesConfigOrDict] = None,
1190+
) -> types.RegisterFilesResponse:
1191+
"""Registers gcs files with the file service."""
1192+
if not isinstance(auth, google.auth.credentials.Credentials):
1193+
raise ValueError(
1194+
'auth must be a google.auth.credentials.Credentials object.'
1195+
)
1196+
if config is None:
1197+
config = types.RegisterFilesConfig()
1198+
else:
1199+
config = types.RegisterFilesConfig.model_validate(config)
1200+
config = config.model_copy(deep=True)
1201+
1202+
http_options = config.http_options or types.HttpOptions()
1203+
headers = http_options.headers or {}
1204+
headers = {k.lower(): v for k, v in headers.items()}
1205+
1206+
token = await _api_client.async_get_token_from_credentials(
1207+
self._api_client, auth
1208+
)
1209+
headers['authorization'] = f'Bearer {token}'
1210+
1211+
if auth.quota_project_id:
1212+
headers['x-goog-user-project'] = auth.quota_project_id
1213+
1214+
http_options.headers = headers
1215+
config.http_options = http_options
1216+
1217+
return await self._register_files(uris=uris, config=config)
1218+
9951219
async def list(
9961220
self, *, config: Optional[types.ListFilesConfigOrDict] = None
9971221
) -> AsyncPager[types.File]:

0 commit comments

Comments
 (0)