|
21 | 21 | from typing import Dict, Iterable, List, MutableSequence, Optional
|
22 | 22 | from unittest import mock
|
23 | 23 |
|
| 24 | +from google.api_core import operation as ga_operation |
24 | 25 | import vertexai
|
25 | 26 | from google.cloud.aiplatform import initializer
|
26 | 27 | from google.cloud.aiplatform_v1 import types as types_v1
|
27 | 28 | from google.cloud.aiplatform_v1.services import (
|
28 | 29 | prediction_service as prediction_service_v1,
|
29 | 30 | )
|
30 | 31 | from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
|
| 32 | +from google.cloud.aiplatform_v1beta1.services import endpoint_service |
31 | 33 | from vertexai import generative_models
|
32 | 34 | from vertexai.preview import (
|
33 | 35 | generative_models as preview_generative_models,
|
|
48 | 50 | )
|
49 | 51 | from vertexai.generative_models import _function_calling_utils
|
50 | 52 | from vertexai.caching import CachedContent
|
| 53 | +from google.protobuf import field_mask_pb2 |
51 | 54 |
|
52 | 55 |
|
53 | 56 | _TEST_PROJECT = "test-project"
|
@@ -1710,6 +1713,115 @@ def test_defs_ref_renaming(self):
|
1710 | 1713 | _fix_schema_dict_for_gapic_in_place(actual)
|
1711 | 1714 | assert actual == expected
|
1712 | 1715 |
|
| 1716 | + @pytest.mark.parametrize( |
| 1717 | + "generative_models", |
| 1718 | + [preview_generative_models], # Only preview supports set_logging_config |
| 1719 | + ) |
| 1720 | + @mock.patch.object(endpoint_service.EndpointServiceClient, "update_endpoint") |
| 1721 | + def test_set_logging_config_for_endpoint( |
| 1722 | + self, mock_update_endpoint, generative_models: generative_models |
| 1723 | + ): |
| 1724 | + endpoint_name = ( |
| 1725 | + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/12345" |
| 1726 | + ) |
| 1727 | + model = generative_models.GenerativeModel(endpoint_name) |
| 1728 | + |
| 1729 | + mock_update_endpoint.return_value = types_v1beta1.Endpoint(name=endpoint_name) |
| 1730 | + |
| 1731 | + enabled = True |
| 1732 | + sampling_rate = 0.5 |
| 1733 | + bigquery_destination = f"bq://{_TEST_PROJECT}.my_dataset.my_table" |
| 1734 | + enable_otel_logging = True |
| 1735 | + |
| 1736 | + model.set_request_response_logging_config( |
| 1737 | + enabled=enabled, |
| 1738 | + sampling_rate=sampling_rate, |
| 1739 | + bigquery_destination=bigquery_destination, |
| 1740 | + enable_otel_logging=enable_otel_logging, |
| 1741 | + ) |
| 1742 | + |
| 1743 | + expected_logging_config = types_v1beta1.PredictRequestResponseLoggingConfig( |
| 1744 | + enabled=enabled, |
| 1745 | + sampling_rate=sampling_rate, |
| 1746 | + bigquery_destination=types_v1beta1.BigQueryDestination( |
| 1747 | + output_uri=bigquery_destination |
| 1748 | + ), |
| 1749 | + enable_otel_logging=enable_otel_logging, |
| 1750 | + ) |
| 1751 | + expected_endpoint = types_v1beta1.Endpoint( |
| 1752 | + name=endpoint_name, |
| 1753 | + predict_request_response_logging_config=expected_logging_config, |
| 1754 | + ) |
| 1755 | + expected_update_mask = field_mask_pb2.FieldMask( |
| 1756 | + paths=["predict_request_response_logging_config"] |
| 1757 | + ) |
| 1758 | + |
| 1759 | + mock_update_endpoint.assert_called_once_with( |
| 1760 | + types_v1beta1.UpdateEndpointRequest( |
| 1761 | + endpoint=expected_endpoint, |
| 1762 | + update_mask=expected_update_mask, |
| 1763 | + ) |
| 1764 | + ) |
| 1765 | + |
| 1766 | + @pytest.mark.parametrize( |
| 1767 | + "generative_models", |
| 1768 | + [preview_generative_models], # Only preview supports set_logging_config |
| 1769 | + ) |
| 1770 | + @mock.patch.object( |
| 1771 | + endpoint_service.EndpointServiceClient, "set_publisher_model_config" |
| 1772 | + ) |
| 1773 | + def test_set_logging_config_for_publisher_model( |
| 1774 | + self, mock_set_publisher_model_config, generative_models: generative_models |
| 1775 | + ): |
| 1776 | + model_name = "gemini-pro" |
| 1777 | + model = generative_models.GenerativeModel(model_name) |
| 1778 | + full_model_name = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/{model_name}" |
| 1779 | + |
| 1780 | + enabled = False |
| 1781 | + sampling_rate = 1.0 |
| 1782 | + bigquery_destination = f"bq://{_TEST_PROJECT}.another_dataset" |
| 1783 | + enable_otel_logging = False |
| 1784 | + |
| 1785 | + mock_operation = mock.Mock(spec=ga_operation.Operation) |
| 1786 | + mock_set_publisher_model_config.return_value = mock_operation |
| 1787 | + mock_operation.result.return_value = types_v1beta1.PublisherModelConfig( |
| 1788 | + logging_config=types_v1beta1.PredictRequestResponseLoggingConfig( |
| 1789 | + enabled=enabled, |
| 1790 | + sampling_rate=sampling_rate, |
| 1791 | + bigquery_destination=types_v1beta1.BigQueryDestination( |
| 1792 | + output_uri=bigquery_destination |
| 1793 | + ), |
| 1794 | + enable_otel_logging=enable_otel_logging, |
| 1795 | + ) |
| 1796 | + ) |
| 1797 | + |
| 1798 | + model.set_request_response_logging_config( |
| 1799 | + enabled=enabled, |
| 1800 | + sampling_rate=sampling_rate, |
| 1801 | + bigquery_destination=bigquery_destination, |
| 1802 | + enable_otel_logging=enable_otel_logging, |
| 1803 | + ) |
| 1804 | + |
| 1805 | + expected_logging_config = types_v1beta1.PredictRequestResponseLoggingConfig( |
| 1806 | + enabled=enabled, |
| 1807 | + sampling_rate=sampling_rate, |
| 1808 | + bigquery_destination=types_v1beta1.BigQueryDestination( |
| 1809 | + output_uri=bigquery_destination |
| 1810 | + ), |
| 1811 | + enable_otel_logging=enable_otel_logging, |
| 1812 | + ) |
| 1813 | + expected_publisher_model_config = types_v1beta1.PublisherModelConfig( |
| 1814 | + logging_config=expected_logging_config |
| 1815 | + ) |
| 1816 | + |
| 1817 | + mock_set_publisher_model_config.assert_called_once_with( |
| 1818 | + types_v1beta1.SetPublisherModelConfigRequest( |
| 1819 | + name=full_model_name, |
| 1820 | + publisher_model_config=expected_publisher_model_config, |
| 1821 | + ) |
| 1822 | + ) |
| 1823 | + mock_operation.result.assert_called_once() |
| 1824 | + |
1713 | 1825 |
|
1714 | 1826 | EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER = {
|
1715 | 1827 | "title": "get_current_weather",
|
|
0 commit comments