|
22 | 22 | from typing import Any, Optional, Union |
23 | 23 | from urllib.parse import urlencode |
24 | 24 |
|
| 25 | +import google.auth |
| 26 | + |
| 27 | +from . import _api_client |
25 | 28 | from . import _api_module |
26 | 29 | from . import _common |
27 | 30 | from . import _extra_utils |
@@ -149,6 +152,33 @@ def _ListFilesResponse_from_mldev( |
149 | 152 | return to_object |
150 | 153 |
|
151 | 154 |
|
| 155 | +def _RegisterFilesParameters_to_mldev( |
| 156 | + from_object: Union[dict[str, Any], object], |
| 157 | + parent_object: Optional[dict[str, Any]] = None, |
| 158 | +) -> dict[str, Any]: |
| 159 | + to_object: dict[str, Any] = {} |
| 160 | + if getv(from_object, ['uris']) is not None: |
| 161 | + setv(to_object, ['uris'], getv(from_object, ['uris'])) |
| 162 | + |
| 163 | + return to_object |
| 164 | + |
| 165 | + |
| 166 | +def _RegisterFilesResponse_from_mldev( |
| 167 | + from_object: Union[dict[str, Any], object], |
| 168 | + parent_object: Optional[dict[str, Any]] = None, |
| 169 | +) -> dict[str, Any]: |
| 170 | + to_object: dict[str, Any] = {} |
| 171 | + if getv(from_object, ['sdkHttpResponse']) is not None: |
| 172 | + setv( |
| 173 | + to_object, ['sdk_http_response'], getv(from_object, ['sdkHttpResponse']) |
| 174 | + ) |
| 175 | + |
| 176 | + if getv(from_object, ['files']) is not None: |
| 177 | + setv(to_object, ['files'], [item for item in getv(from_object, ['files'])]) |
| 178 | + |
| 179 | + return to_object |
| 180 | + |
| 181 | + |
152 | 182 | class Files(_api_module.BaseModule): |
153 | 183 |
|
154 | 184 | def _list( |
@@ -402,6 +432,69 @@ def delete( |
402 | 432 | self._api_client._verify_response(return_value) |
403 | 433 | return return_value |
404 | 434 |
|
| 435 | + def _register_files( |
| 436 | + self, |
| 437 | + *, |
| 438 | + uris: list[str], |
| 439 | + config: Optional[types.RegisterFilesConfigOrDict] = None, |
| 440 | + ) -> types.RegisterFilesResponse: |
| 441 | + parameter_model = types._RegisterFilesParameters( |
| 442 | + uris=uris, |
| 443 | + config=config, |
| 444 | + ) |
| 445 | + |
| 446 | + request_url_dict: Optional[dict[str, str]] |
| 447 | + if self._api_client.vertexai: |
| 448 | + raise ValueError( |
| 449 | + 'This method is only supported in the Gemini Developer client.' |
| 450 | + ) |
| 451 | + else: |
| 452 | + request_dict = _RegisterFilesParameters_to_mldev(parameter_model) |
| 453 | + request_url_dict = request_dict.get('_url') |
| 454 | + if request_url_dict: |
| 455 | + path = 'files:register'.format_map(request_url_dict) |
| 456 | + else: |
| 457 | + path = 'files:register' |
| 458 | + |
| 459 | + query_params = request_dict.get('_query') |
| 460 | + if query_params: |
| 461 | + path = f'{path}?{urlencode(query_params)}' |
| 462 | + # TODO: remove the hack that pops config. |
| 463 | + request_dict.pop('config', None) |
| 464 | + |
| 465 | + http_options: Optional[types.HttpOptions] = None |
| 466 | + if ( |
| 467 | + parameter_model.config is not None |
| 468 | + and parameter_model.config.http_options is not None |
| 469 | + ): |
| 470 | + http_options = parameter_model.config.http_options |
| 471 | + |
| 472 | + request_dict = _common.convert_to_dict(request_dict) |
| 473 | + request_dict = _common.encode_unserializable_types(request_dict) |
| 474 | + |
| 475 | + response = self._api_client.request( |
| 476 | + 'post', path, request_dict, http_options |
| 477 | + ) |
| 478 | + |
| 479 | + if config is not None and getattr( |
| 480 | + config, 'should_return_http_response', None |
| 481 | + ): |
| 482 | + return_value = types.RegisterFilesResponse(sdk_http_response=response) |
| 483 | + self._api_client._verify_response(return_value) |
| 484 | + return return_value |
| 485 | + |
| 486 | + response_dict = {} if not response.body else json.loads(response.body) |
| 487 | + |
| 488 | + if not self._api_client.vertexai: |
| 489 | + response_dict = _RegisterFilesResponse_from_mldev(response_dict) |
| 490 | + |
| 491 | + return_value = types.RegisterFilesResponse._from_response( |
| 492 | + response=response_dict, kwargs=parameter_model.model_dump() |
| 493 | + ) |
| 494 | + |
| 495 | + self._api_client._verify_response(return_value) |
| 496 | + return return_value |
| 497 | + |
405 | 498 | def upload( |
406 | 499 | self, |
407 | 500 | *, |
@@ -559,6 +652,39 @@ def download( |
559 | 652 |
|
560 | 653 | return data |
561 | 654 |
|
| 655 | + def register_files( |
| 656 | + self, |
| 657 | + *, |
| 658 | + auth: google.auth.credentials.Credentials, |
| 659 | + uris: list[str], |
| 660 | + config: Optional[types.RegisterFilesConfigOrDict] = None, |
| 661 | + ) -> types.RegisterFilesResponse: |
| 662 | + """Registers gcs files with the file service.""" |
| 663 | + if not isinstance(auth, google.auth.credentials.Credentials): |
| 664 | + raise ValueError( |
| 665 | + 'auth must be a google.auth.credentials.Credentials object.' |
| 666 | + ) |
| 667 | + if config is None: |
| 668 | + config = types.RegisterFilesConfig() |
| 669 | + else: |
| 670 | + config = types.RegisterFilesConfig.model_validate(config) |
| 671 | + config = config.model_copy(deep=True) |
| 672 | + |
| 673 | + http_options = config.http_options or types.HttpOptions() |
| 674 | + headers = http_options.headers or {} |
| 675 | + headers = {k.lower(): v for k, v in headers.items()} |
| 676 | + |
| 677 | + token = _api_client.get_token_from_credentials(self._api_client, auth) |
| 678 | + headers['authorization'] = f'Bearer {token}' |
| 679 | + |
| 680 | + if auth.quota_project_id: |
| 681 | + headers['x-goog-user-project'] = auth.quota_project_id |
| 682 | + |
| 683 | + http_options.headers = headers |
| 684 | + config.http_options = http_options |
| 685 | + |
| 686 | + return self._register_files(uris=uris, config=config) |
| 687 | + |
562 | 688 | def list( |
563 | 689 | self, *, config: Optional[types.ListFilesConfigOrDict] = None |
564 | 690 | ) -> Pager[types.File]: |
@@ -845,6 +971,69 @@ async def delete( |
845 | 971 | self._api_client._verify_response(return_value) |
846 | 972 | return return_value |
847 | 973 |
|
| 974 | + async def _register_files( |
| 975 | + self, |
| 976 | + *, |
| 977 | + uris: list[str], |
| 978 | + config: Optional[types.RegisterFilesConfigOrDict] = None, |
| 979 | + ) -> types.RegisterFilesResponse: |
| 980 | + parameter_model = types._RegisterFilesParameters( |
| 981 | + uris=uris, |
| 982 | + config=config, |
| 983 | + ) |
| 984 | + |
| 985 | + request_url_dict: Optional[dict[str, str]] |
| 986 | + if self._api_client.vertexai: |
| 987 | + raise ValueError( |
| 988 | + 'This method is only supported in the Gemini Developer client.' |
| 989 | + ) |
| 990 | + else: |
| 991 | + request_dict = _RegisterFilesParameters_to_mldev(parameter_model) |
| 992 | + request_url_dict = request_dict.get('_url') |
| 993 | + if request_url_dict: |
| 994 | + path = 'files:register'.format_map(request_url_dict) |
| 995 | + else: |
| 996 | + path = 'files:register' |
| 997 | + |
| 998 | + query_params = request_dict.get('_query') |
| 999 | + if query_params: |
| 1000 | + path = f'{path}?{urlencode(query_params)}' |
| 1001 | + # TODO: remove the hack that pops config. |
| 1002 | + request_dict.pop('config', None) |
| 1003 | + |
| 1004 | + http_options: Optional[types.HttpOptions] = None |
| 1005 | + if ( |
| 1006 | + parameter_model.config is not None |
| 1007 | + and parameter_model.config.http_options is not None |
| 1008 | + ): |
| 1009 | + http_options = parameter_model.config.http_options |
| 1010 | + |
| 1011 | + request_dict = _common.convert_to_dict(request_dict) |
| 1012 | + request_dict = _common.encode_unserializable_types(request_dict) |
| 1013 | + |
| 1014 | + response = await self._api_client.async_request( |
| 1015 | + 'post', path, request_dict, http_options |
| 1016 | + ) |
| 1017 | + |
| 1018 | + if config is not None and getattr( |
| 1019 | + config, 'should_return_http_response', None |
| 1020 | + ): |
| 1021 | + return_value = types.RegisterFilesResponse(sdk_http_response=response) |
| 1022 | + self._api_client._verify_response(return_value) |
| 1023 | + return return_value |
| 1024 | + |
| 1025 | + response_dict = {} if not response.body else json.loads(response.body) |
| 1026 | + |
| 1027 | + if not self._api_client.vertexai: |
| 1028 | + response_dict = _RegisterFilesResponse_from_mldev(response_dict) |
| 1029 | + |
| 1030 | + return_value = types.RegisterFilesResponse._from_response( |
| 1031 | + response=response_dict, kwargs=parameter_model.model_dump() |
| 1032 | + ) |
| 1033 | + |
| 1034 | + self._api_client._verify_response(return_value) |
| 1035 | + return return_value |
| 1036 | + |
848 | 1037 | async def upload( |
849 | 1038 | self, |
850 | 1039 | *, |
@@ -992,6 +1181,41 @@ async def download( |
992 | 1181 |
|
993 | 1182 | return data |
994 | 1183 |
|
| 1184 | + async def register_files( |
| 1185 | + self, |
| 1186 | + *, |
| 1187 | + auth: google.auth.credentials.Credentials, |
| 1188 | + uris: list[str], |
| 1189 | + config: Optional[types.RegisterFilesConfigOrDict] = None, |
| 1190 | + ) -> types.RegisterFilesResponse: |
| 1191 | + """Registers gcs files with the file service.""" |
| 1192 | + if not isinstance(auth, google.auth.credentials.Credentials): |
| 1193 | + raise ValueError( |
| 1194 | + 'auth must be a google.auth.credentials.Credentials object.' |
| 1195 | + ) |
| 1196 | + if config is None: |
| 1197 | + config = types.RegisterFilesConfig() |
| 1198 | + else: |
| 1199 | + config = types.RegisterFilesConfig.model_validate(config) |
| 1200 | + config = config.model_copy(deep=True) |
| 1201 | + |
| 1202 | + http_options = config.http_options or types.HttpOptions() |
| 1203 | + headers = http_options.headers or {} |
| 1204 | + headers = {k.lower(): v for k, v in headers.items()} |
| 1205 | + |
| 1206 | + token = await _api_client.async_get_token_from_credentials( |
| 1207 | + self._api_client, auth |
| 1208 | + ) |
| 1209 | + headers['authorization'] = f'Bearer {token}' |
| 1210 | + |
| 1211 | + if auth.quota_project_id: |
| 1212 | + headers['x-goog-user-project'] = auth.quota_project_id |
| 1213 | + |
| 1214 | + http_options.headers = headers |
| 1215 | + config.http_options = http_options |
| 1216 | + |
| 1217 | + return await self._register_files(uris=uris, config=config) |
| 1218 | + |
995 | 1219 | async def list( |
996 | 1220 | self, *, config: Optional[types.ListFilesConfigOrDict] = None |
997 | 1221 | ) -> AsyncPager[types.File]: |
|
0 commit comments