From e9eb159756dfe90c9f72818204fa74d05096aec6 Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Thu, 14 Sep 2023 14:06:47 -0700 Subject: [PATCH] feat: Added async prediction and explanation support to the `Endpoint` class * Added the `Endpoint.predict_async` method * Added the `Endpoint.explain_async` method * Made it possible to use async clients in classes derived from `VertexAiResourceNounWithFutureManager` that use `@optional_sync`. PiperOrigin-RevId: 565472250 --- google/cloud/aiplatform/compat/__init__.py | 6 + .../aiplatform/compat/services/__init__.py | 8 + google/cloud/aiplatform/models.py | 164 +++++++++++++++++- google/cloud/aiplatform/utils/__init__.py | 16 ++ .../aiplatform/test_model_interactions.py | 9 + tests/unit/aiplatform/test_endpoints.py | 98 +++++++++++ 6 files changed, 295 insertions(+), 6 deletions(-) diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py index f4d7548025..965b8928cf 100644 --- a/google/cloud/aiplatform/compat/__init__.py +++ b/google/cloud/aiplatform/compat/__init__.py @@ -39,6 +39,9 @@ services.model_garden_service_client = services.model_garden_service_client_v1beta1 services.pipeline_service_client = services.pipeline_service_client_v1beta1 services.prediction_service_client = services.prediction_service_client_v1beta1 + services.prediction_service_async_client = ( + services.prediction_service_async_client_v1beta1 + ) services.schedule_service_client = services.schedule_service_client_v1beta1 services.specialist_pool_service_client = ( services.specialist_pool_service_client_v1beta1 @@ -144,6 +147,9 @@ services.model_service_client = services.model_service_client_v1 services.pipeline_service_client = services.pipeline_service_client_v1 services.prediction_service_client = services.prediction_service_client_v1 + services.prediction_service_async_client = ( + services.prediction_service_async_client_v1 + ) services.schedule_service_client = services.schedule_service_client_v1 services.specialist_pool_service_client = services.specialist_pool_service_client_v1 services.tensorboard_service_client = services.tensorboard_service_client_v1 diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py index d2e464425b..5901b643cf 100644 --- a/google/cloud/aiplatform/compat/services/__init__.py +++ b/google/cloud/aiplatform/compat/services/__init__.py @@ -60,6 +60,9 @@ from google.cloud.aiplatform_v1beta1.services.prediction_service import ( client as prediction_service_client_v1beta1, ) +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + async_client as prediction_service_async_client_v1beta1, +) from google.cloud.aiplatform_v1beta1.services.schedule_service import ( client as schedule_service_client_v1beta1, ) @@ -109,6 +112,9 @@ from google.cloud.aiplatform_v1.services.prediction_service import ( client as prediction_service_client_v1, ) +from google.cloud.aiplatform_v1.services.prediction_service import ( + async_client as prediction_service_async_client_v1, +) from google.cloud.aiplatform_v1.services.schedule_service import ( client as schedule_service_client_v1, ) @@ -136,6 +142,7 @@ model_service_client_v1, pipeline_service_client_v1, prediction_service_client_v1, + prediction_service_async_client_v1, schedule_service_client_v1, specialist_pool_service_client_v1, tensorboard_service_client_v1, @@ -155,6 +162,7 @@ persistent_resource_service_client_v1beta1, pipeline_service_client_v1beta1, 
    prediction_service_client_v1beta1,
+    prediction_service_async_client_v1beta1,
     schedule_service_client_v1beta1,
     specialist_pool_service_client_v1beta1,
     metadata_service_client_v1beta1,
diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index 2a755a5288..dde470b743 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -14,6 +14,7 @@
 #    See the License for the specific language governing permissions and
 #    limitations under the License.
 #
+import asyncio
 import json
 import pathlib
 import re
@@ -226,7 +227,10 @@ def __init__(
 
         # Lazy load the Endpoint gca_resource until needed
         self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)
 
-        self._prediction_client = self._instantiate_prediction_client(
+        (
+            self._prediction_client,
+            self._prediction_async_client,
+        ) = self._instantiate_prediction_clients(
             location=self.location,
             credentials=credentials,
         )
@@ -572,7 +576,10 @@ def _construct_sdk_resource_from_gapic(
             credentials=credentials,
         )
 
-        endpoint._prediction_client = cls._instantiate_prediction_client(
+        (
+            endpoint._prediction_client,
+            endpoint._prediction_async_client,
+        ) = cls._instantiate_prediction_clients(
             location=endpoint.location,
             credentials=credentials,
         )
@@ -1384,10 +1391,12 @@ def _undeploy(
             self._sync_gca_resource()
 
     @staticmethod
-    def _instantiate_prediction_client(
+    def _instantiate_prediction_clients(
         location: Optional[str] = None,
         credentials: Optional[auth_credentials.Credentials] = None,
-    ) -> utils.PredictionClientWithOverride:
+    ) -> Tuple[
+        utils.PredictionClientWithOverride, utils.PredictionAsyncClientWithOverride
+    ]:
         """Helper method that instantiates prediction clients with optional
         overrides for this endpoint.
 
@@ -1399,14 +1408,34 @@
 
         Returns:
             prediction_client (prediction_service_client.PredictionServiceClient):
-            Initialized prediction client with optional overrides.
+            prediction_async_client (PredictionServiceAsyncClient):
+            Initialized prediction clients with optional overrides.
         """
-        return initializer.global_config.create_client(
+
+        # Create an event loop if one does not already exist.
+        # The PredictionServiceAsyncClient constructor calls `asyncio.get_event_loop`,
+        # which fails when there is no event loop (by default, none exists in the
+        # non-main threads of the thread pool used when `sync=False`).
+        try:
+            asyncio.get_event_loop()
+        except RuntimeError:
+            asyncio.set_event_loop(asyncio.new_event_loop())
+
+        async_client = initializer.global_config.create_client(
+            client_class=utils.PredictionAsyncClientWithOverride,
+            credentials=credentials,
+            location_override=location,
+            prediction_client=True,
+        )
+        # We could use `client = async_client._client`, but then `client` would be
+        # a concrete `PredictionServiceClient`, not `PredictionClientWithOverride`.
+        client = initializer.global_config.create_client(
             client_class=utils.PredictionClientWithOverride,
             credentials=credentials,
             location_override=location,
             prediction_client=True,
         )
+        return (client, async_client)
 
     def update(
         self,
@@ -1581,6 +1610,65 @@ def predict(
             model_resource_name=prediction_response.model,
         )
 
+    async def predict_async(
+        self,
+        instances: List,
+        *,
+        parameters: Optional[Dict] = None,
+        timeout: Optional[float] = None,
+    ) -> Prediction:
+        """Make an asynchronous prediction against this Endpoint.
+
+        Example usage:
+        ```
+        response = await my_endpoint.predict_async(instances=[...])
+        my_predictions = response.predictions
+        ```
+
+        Args:
+            instances (List):
+                Required.
The instances that are the input to the + prediction call. A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. + The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + parameters (Dict): + Optional. The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + timeout (float): Optional. The timeout for this request in seconds. + + Returns: + prediction (aiplatform.Prediction): + Prediction with returned predictions and Model ID. + """ + self.wait() + + prediction_response = await self._prediction_async_client.predict( + endpoint=self._gca_resource.name, + instances=instances, + parameters=parameters, + timeout=timeout, + ) + + return Prediction( + predictions=[ + json_format.MessageToDict(item) + for item in prediction_response.predictions.pb + ], + deployed_model_id=prediction_response.deployed_model_id, + model_version_id=prediction_response.model_version_id, + model_resource_name=prediction_response.model, + ) + def raw_predict( self, body: bytes, headers: Dict[str, str] ) -> requests.models.Response: @@ -1676,6 +1764,70 @@ def explain( explanations=explain_response.explanations, ) + async def explain_async( + self, + instances: List[Dict], + *, + parameters: Optional[Dict] = None, + deployed_model_id: Optional[str] = None, + timeout: Optional[float] = None, + ) -> Prediction: + """Make a prediction with explanations against this Endpoint. + + Example usage: + ``` + response = await my_endpoint.explain_async(instances=[...]) + my_explanations = response.explanations + ``` + + Args: + instances (List): + Required. The instances that are the input to the + prediction call. A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. + The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + parameters (Dict): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + deployed_model_id (str): + Optional. If specified, this ExplainRequest will be served by the + chosen DeployedModel, overriding this Endpoint's traffic split. + timeout (float): Optional. The timeout for this request in seconds. + + Returns: + prediction (aiplatform.Prediction): + Prediction with returned predictions, explanations, and Model ID. 
+ """ + self.wait() + + explain_response = await self._prediction_async_client.explain( + endpoint=self.resource_name, + instances=instances, + parameters=parameters, + deployed_model_id=deployed_model_id, + timeout=timeout, + ) + + return Prediction( + predictions=[ + json_format.MessageToDict(item) + for item in explain_response.predictions.pb + ], + deployed_model_id=explain_response.deployed_model_id, + explanations=explain_response.explanations, + ) + @classmethod def list( cls, diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py index a5dea7a112..65ff78a038 100644 --- a/google/cloud/aiplatform/utils/__init__.py +++ b/google/cloud/aiplatform/utils/__init__.py @@ -49,6 +49,7 @@ model_service_client_v1beta1, pipeline_service_client_v1beta1, prediction_service_client_v1beta1, + prediction_service_async_client_v1beta1, schedule_service_client_v1beta1, tensorboard_service_client_v1beta1, vizier_service_client_v1beta1, @@ -68,6 +69,7 @@ model_service_client_v1, pipeline_service_client_v1, prediction_service_client_v1, + prediction_service_async_client_v1, schedule_service_client_v1, tensorboard_service_client_v1, vizier_service_client_v1, @@ -89,6 +91,7 @@ index_endpoint_service_client_v1beta1.IndexEndpointServiceClient, model_service_client_v1beta1.ModelServiceClient, prediction_service_client_v1beta1.PredictionServiceClient, + prediction_service_async_client_v1beta1.PredictionServiceAsyncClient, pipeline_service_client_v1beta1.PipelineServiceClient, job_service_client_v1beta1.JobServiceClient, match_service_client_v1beta1.MatchServiceClient, @@ -104,6 +107,7 @@ metadata_service_client_v1.MetadataServiceClient, model_service_client_v1.ModelServiceClient, prediction_service_client_v1.PredictionServiceClient, + prediction_service_async_client_v1.PredictionServiceAsyncClient, pipeline_service_client_v1.PipelineServiceClient, job_service_client_v1.JobServiceClient, schedule_service_client_v1.ScheduleServiceClient, @@ -616,6 +620,18 @@ class PredictionClientWithOverride(ClientWithOverride): ) +class PredictionAsyncClientWithOverride(ClientWithOverride): + _is_temporary = False + _default_version = compat.DEFAULT_VERSION + _version_map = ( + (compat.V1, prediction_service_async_client_v1.PredictionServiceAsyncClient), + ( + compat.V1BETA1, + prediction_service_async_client_v1beta1.PredictionServiceAsyncClient, + ), + ) + + class MatchClientWithOverride(ClientWithOverride): _is_temporary = False _default_version = compat.V1BETA1 diff --git a/tests/system/aiplatform/test_model_interactions.py b/tests/system/aiplatform/test_model_interactions.py index dc222ab79e..5f24ccc753 100644 --- a/tests/system/aiplatform/test_model_interactions.py +++ b/tests/system/aiplatform/test_model_interactions.py @@ -16,6 +16,7 @@ # import json +import pytest from google.cloud import aiplatform @@ -64,3 +65,11 @@ def test_prediction(self): ) assert raw_prediction_response.status_code == 200 assert len(json.loads(raw_prediction_response.text)) == 1 + + @pytest.mark.asyncio + async def test_endpoint_predict_async(self): + # Test the Endpoint.predict_async method. 
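+        # `predict_async` returns a coroutine, so this test is itself async and
+        # relies on the `pytest.mark.asyncio` marker above to run on an event loop.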
+ prediction_response = await self.endpoint.predict_async( + instances=[_PREDICTION_INSTANCE] + ) + assert len(prediction_response.predictions) == 1 diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index 8e39790676..2b162a5eac 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -43,6 +43,7 @@ endpoint_service_client, endpoint_service_client_v1beta1, prediction_service_client, + prediction_service_async_client, deployment_resource_pool_service_client_v1beta1, ) @@ -464,6 +465,22 @@ def predict_client_predict_mock(): yield predict_mock +@pytest.fixture +def predict_async_client_predict_mock(): + response = gca_prediction_service.PredictResponse( + deployed_model_id=_TEST_MODEL_ID, + model_version_id=_TEST_VERSION_ID, + model=_TEST_MODEL_NAME, + ) + response.predictions.extend(_TEST_PREDICTION) + with mock.patch.object( + target=prediction_service_async_client.PredictionServiceAsyncClient, + attribute="predict", + return_value=response, + ) as predict_mock: + yield predict_mock + + @pytest.fixture def predict_client_explain_mock(): with mock.patch.object( @@ -480,6 +497,23 @@ def predict_client_explain_mock(): yield predict_mock +@pytest.fixture +def predict_async_client_explain_mock(): + response = gca_prediction_service.ExplainResponse( + deployed_model_id=_TEST_MODEL_ID, + ) + response.predictions.extend(_TEST_PREDICTION) + response.explanations.extend(_TEST_EXPLANATIONS) + response.explanations[0].attributions.extend(_TEST_ATTRIBUTIONS) + + with mock.patch.object( + target=prediction_service_async_client.PredictionServiceAsyncClient, + attribute="explain", + return_value=response, + ) as explain_mock: + yield explain_mock + + @pytest.fixture def get_drp_mock(): with mock.patch.object( @@ -624,6 +658,12 @@ def test_constructor(self, create_endpoint_client_mock): location_override=_TEST_LOCATION, appended_user_agent=None, ), + mock.call( + client_class=utils.PredictionAsyncClientWithOverride, + credentials=None, + location_override=_TEST_LOCATION, + prediction_client=True, + ), mock.call( client_class=utils.PredictionClientWithOverride, credentials=None, @@ -714,6 +754,12 @@ def test_constructor_with_custom_credentials(self, create_endpoint_client_mock): location_override=_TEST_LOCATION, appended_user_agent=None, ), + mock.call( + client_class=utils.PredictionAsyncClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION, + prediction_client=True, + ), mock.call( client_class=utils.PredictionClientWithOverride, credentials=creds, @@ -1841,6 +1887,30 @@ def test_predict(self, predict_client_predict_mock): timeout=None, ) + @pytest.mark.asyncio + @pytest.mark.usefixtures("get_endpoint_mock") + async def test_predict_async(self, predict_async_client_predict_mock): + """Tests the Endpoint.predict_async method.""" + test_endpoint = models.Endpoint(_TEST_ID) + test_prediction = await test_endpoint.predict_async( + instances=_TEST_INSTANCES, parameters={"param": 3.0} + ) + + true_prediction = models.Prediction( + predictions=_TEST_PREDICTION, + deployed_model_id=_TEST_ID, + model_version_id=_TEST_VERSION_ID, + model_resource_name=_TEST_MODEL_NAME, + ) + + assert true_prediction == test_prediction + predict_async_client_predict_mock.assert_called_once_with( + endpoint=_TEST_ENDPOINT_NAME, + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + timeout=None, + ) + @pytest.mark.usefixtures("get_endpoint_mock") def test_explain(self, predict_client_explain_mock): @@ -1868,6 +1938,34 @@ 
def test_explain(self, predict_client_explain_mock): timeout=None, ) + @pytest.mark.asyncio + @pytest.mark.usefixtures("get_endpoint_mock") + async def test_explain_async(self, predict_async_client_explain_mock): + """Tests the Endpoint.explain_async method.""" + test_endpoint = models.Endpoint(_TEST_ID) + test_prediction = await test_endpoint.explain_async( + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + deployed_model_id=_TEST_MODEL_ID, + ) + expected_explanations = _TEST_EXPLANATIONS + expected_explanations[0].attributions.extend(_TEST_ATTRIBUTIONS) + + expected_prediction = models.Prediction( + predictions=_TEST_PREDICTION, + deployed_model_id=_TEST_ID, + explanations=expected_explanations, + ) + + assert expected_prediction == test_prediction + predict_async_client_explain_mock.assert_called_once_with( + endpoint=_TEST_ENDPOINT_NAME, + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + deployed_model_id=_TEST_MODEL_ID, + timeout=None, + ) + @pytest.mark.usefixtures("get_endpoint_mock") def test_predict_with_timeout(self, predict_client_predict_mock):
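
Usage sketch for the new methods (illustrative only: the endpoint resource name and instance payloads below are placeholders, and `explain_async` additionally requires the deployed model to have been uploaded with an explanation spec):
```
import asyncio

from google.cloud import aiplatform


async def main():
    # Placeholder resource name; point this at a real endpoint that has a
    # deployed model.
    endpoint = aiplatform.Endpoint(
        "projects/my-project/locations/us-central1/endpoints/1234567890"
    )
    # Both RPCs are issued concurrently on the same event loop.
    prediction, explanation = await asyncio.gather(
        endpoint.predict_async(instances=[{"feature": 1.0}]),
        endpoint.explain_async(instances=[{"feature": 1.0}]),
    )
    print(prediction.predictions)
    print(explanation.explanations)


asyncio.run(main())
```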