Skip to content

Commit 208e500

Browse files
vertex-sdk-bot authored and copybara-github committed
No public description
PiperOrigin-RevId: 764358461
1 parent 47ab05a commit 208e500

File tree

4 files changed

+195
-5
lines changed

4 files changed

+195
-5
lines changed

owlbot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
"noxfile.py",
8484
"testing",
8585
"docs/conf.py",
86+
"*.tar.gz"
8687
],
8788
)
8889
has_generator_updates = True

setup.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,8 @@
4949
profiler_extra_require = [
5050
"tensorboard-plugin-profile >= 2.4.0, <2.18.0", # <3.0.0",
5151
"werkzeug >= 2.0.0, <4.0.0",
52-
"tensorflow >=2.4.0, <3.0.0",
5352
]
54-
tensorboard_extra_require = [
55-
"tensorflow >=2.3.0, <3.0.0; python_version<='3.11'"
56-
] + profiler_extra_require
53+
tensorboard_extra_require = profiler_extra_require
5754

5855
metadata_extra_require = ["pandas >= 1.0.0", "numpy>=1.15.0"]
5956
xai_extra_require = ["tensorflow >=2.3.0, <3.0.0"]

tests/unit/vertexai/test_generative_models.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
from typing import Dict, Iterable, List, MutableSequence, Optional
2222
from unittest import mock
2323

24+
from google.api_core import operation as ga_operation
2425
import vertexai
2526
from google.cloud.aiplatform import initializer
2627
from google.cloud.aiplatform_v1 import types as types_v1
2728
from google.cloud.aiplatform_v1.services import (
2829
prediction_service as prediction_service_v1,
2930
)
3031
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
32+
from google.cloud.aiplatform_v1beta1.services import endpoint_service
3133
from vertexai import generative_models
3234
from vertexai.preview import (
3335
generative_models as preview_generative_models,
@@ -48,6 +50,7 @@
4850
)
4951
from vertexai.generative_models import _function_calling_utils
5052
from vertexai.caching import CachedContent
53+
from google.protobuf import field_mask_pb2
5154

5255

5356
_TEST_PROJECT = "test-project"
@@ -1710,6 +1713,115 @@ def test_defs_ref_renaming(self):
17101713
_fix_schema_dict_for_gapic_in_place(actual)
17111714
assert actual == expected
17121715

1716+
@pytest.mark.parametrize(
    "generative_models",
    [preview_generative_models],  # Only preview supports set_logging_config
)
@mock.patch.object(endpoint_service.EndpointServiceClient, "update_endpoint")
def test_set_logging_config_for_endpoint(
    self, mock_update_endpoint, generative_models: generative_models
):
    """Setting the logging config on an endpoint-backed model issues UpdateEndpoint."""
    endpoint_name = (
        f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/12345"
    )
    mock_update_endpoint.return_value = types_v1beta1.Endpoint(name=endpoint_name)
    model = generative_models.GenerativeModel(endpoint_name)

    bq_table = f"bq://{_TEST_PROJECT}.my_dataset.my_table"

    model.set_request_response_logging_config(
        enabled=True,
        sampling_rate=0.5,
        bigquery_destination=bq_table,
        enable_otel_logging=True,
    )

    # The endpoint resource name must route to UpdateEndpoint, replacing only
    # the predict_request_response_logging_config field via the update mask.
    want_logging_config = types_v1beta1.PredictRequestResponseLoggingConfig(
        enabled=True,
        sampling_rate=0.5,
        bigquery_destination=types_v1beta1.BigQueryDestination(
            output_uri=bq_table
        ),
        enable_otel_logging=True,
    )
    mock_update_endpoint.assert_called_once_with(
        types_v1beta1.UpdateEndpointRequest(
            endpoint=types_v1beta1.Endpoint(
                name=endpoint_name,
                predict_request_response_logging_config=want_logging_config,
            ),
            update_mask=field_mask_pb2.FieldMask(
                paths=["predict_request_response_logging_config"]
            ),
        )
    )
1765+
1766+
@pytest.mark.parametrize(
    "generative_models",
    [preview_generative_models],  # Only preview supports set_logging_config
)
@mock.patch.object(
    endpoint_service.EndpointServiceClient, "set_publisher_model_config"
)
def test_set_logging_config_for_publisher_model(
    self, mock_set_publisher_model_config, generative_models: generative_models
):
    """Setting the logging config on a publisher model issues SetPublisherModelConfig."""
    model_name = "gemini-pro"
    model = generative_models.GenerativeModel(model_name)
    full_model_name = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/{model_name}"

    bq_dataset = f"bq://{_TEST_PROJECT}.another_dataset"
    want_logging_config = types_v1beta1.PredictRequestResponseLoggingConfig(
        enabled=False,
        sampling_rate=1.0,
        bigquery_destination=types_v1beta1.BigQueryDestination(
            output_uri=bq_dataset
        ),
        enable_otel_logging=False,
    )

    # The publisher-model path returns a long-running operation; the SDK is
    # expected to block on .result() before returning.
    mock_operation = mock.Mock(spec=ga_operation.Operation)
    mock_operation.result.return_value = types_v1beta1.PublisherModelConfig(
        logging_config=want_logging_config
    )
    mock_set_publisher_model_config.return_value = mock_operation

    model.set_request_response_logging_config(
        enabled=False,
        sampling_rate=1.0,
        bigquery_destination=bq_dataset,
        enable_otel_logging=False,
    )

    mock_set_publisher_model_config.assert_called_once_with(
        types_v1beta1.SetPublisherModelConfigRequest(
            name=full_model_name,
            publisher_model_config=types_v1beta1.PublisherModelConfig(
                logging_config=want_logging_config
            ),
        )
    )
    mock_operation.result.assert_called_once()
1824+
17131825

17141826
EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER = {
17151827
"title": "get_current_weather",

vertexai/generative_models/_generative_models.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Google LLC
1+
# Copyright 2025 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -48,6 +48,7 @@
4848
llm_utility_service as llm_utility_service_v1,
4949
)
5050
from google.cloud.aiplatform_v1beta1 import types as aiplatform_types
51+
from google.cloud.aiplatform_v1beta1.services import endpoint_service
5152
from google.cloud.aiplatform_v1beta1.services import prediction_service
5253
from google.cloud.aiplatform_v1beta1.services import llm_utility_service
5354
from google.cloud.aiplatform_v1beta1.types import (
@@ -59,6 +60,7 @@
5960
)
6061
from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types
6162
from google.protobuf import json_format
63+
from google.protobuf import field_mask_pb2
6264
import warnings
6365

6466
if TYPE_CHECKING:
@@ -502,6 +504,19 @@ def _llm_utility_client(self) -> llm_utility_service.LlmUtilityServiceClient:
502504
api_key=api_key,
503505
)
504506

507+
@functools.cached_property
def _endpoint_client(self) -> endpoint_service.EndpointServiceClient:
    """Lazily creates and caches an EndpointService client for this model's location.

    NOTE: this does not work with GCP Express, but it is better to
    configure the client correctly and let the service raise the error.
    """
    config = aiplatform_initializer.global_config
    # When both an API key and a project are configured, prefer the
    # project-based credentials and drop the API key.
    api_key = None if (config.api_key and config.project) else config.api_key
    return config.create_client(
        client_class=endpoint_service.EndpointServiceClient,
        location_override=self._location,
        api_key=api_key,
    )
519+
505520
@functools.cached_property
506521
def _llm_utility_async_client(
507522
self,
@@ -3612,3 +3627,68 @@ def start_chat(
36123627
response_validation=response_validation,
36133628
responder=responder,
36143629
)
3630+
3631+
def set_request_response_logging_config(
    self,
    *,
    enabled: bool,
    sampling_rate: float,
    bigquery_destination: str,
    enable_otel_logging: Optional[bool] = None,
) -> Union[aiplatform_types.PublisherModelConfig, aiplatform_types.Endpoint]:
    """Sets the request/response logging config.

    Args:
        enabled: If logging is enabled or not.
        sampling_rate: Percentage of requests to be logged, expressed as a
            fraction in range(0,1].
        bigquery_destination: BigQuery table for logging. If only given a
            project, a new dataset will be created with name
            ``logging_<endpoint-display-name>_<endpoint-id>`` where will
            be made BigQuery-dataset-name compatible (e.g. most special
            characters will become underscores). If no table name is
            given, a new table will be created with name
            ``request_response_logging``
        enable_otel_logging: This field is used for large models. If true, in
            addition to the original large model logs, logs will be converted in
            OTel schema format, and saved in otel_log column. Default
            value is false.
    Returns:
        The updated PublisherModelConfig or Endpoint.
    """
    client = self._endpoint_client
    resource_name = self._prediction_resource_name

    logging_config = aiplatform_types.PredictRequestResponseLoggingConfig(
        enabled=enabled,
        sampling_rate=sampling_rate,
        bigquery_destination=aiplatform_types.BigQueryDestination(
            output_uri=bigquery_destination
        ),
        enable_otel_logging=enable_otel_logging,
    )

    # Resource names that parse as an endpoint path go through UpdateEndpoint;
    # everything else is treated as a publisher model.
    if client.parse_endpoint_path(resource_name):
        # Mask the update so only the logging config field is replaced.
        request = aiplatform_types.UpdateEndpointRequest(
            endpoint=aiplatform_types.Endpoint(
                name=resource_name,
                predict_request_response_logging_config=logging_config,
            ),
            update_mask=field_mask_pb2.FieldMask(
                paths=["predict_request_response_logging_config"]
            ),
        )
        return client.update_endpoint(request)

    # Publisher models return a long-running operation; block until it is done.
    lro = client.set_publisher_model_config(
        aiplatform_types.SetPublisherModelConfigRequest(
            name=resource_name,
            publisher_model_config=aiplatform_types.PublisherModelConfig(
                logging_config=logging_config
            ),
        )
    )
    return lro.result()

0 commit comments

Comments
 (0)