diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py
index e88de02934..55ca1c7829 100644
--- a/google/cloud/aiplatform/base.py
+++ b/google/cloud/aiplatform/base.py
@@ -241,7 +241,7 @@ def _are_futures_done(self) -> bool:
         return self.__latest_future is None
 
     def wait(self):
-        """Helper method to that blocks until all futures are complete."""
+        """Helper method that blocks until all futures are complete."""
        future = self.__latest_future
         if future:
             futures.wait([future], return_when=futures.FIRST_EXCEPTION)
@@ -974,7 +974,11 @@ def _sync_object_with_future_result(
             "_gca_resource",
             "credentials",
         ]
-        optional_sync_attributes = ["_prediction_client"]
+        optional_sync_attributes = [
+            "_prediction_client",
+            "_authorized_session",
+            "_raw_predict_request_url",
+        ]
 
         for attribute in sync_attributes:
             setattr(self, attribute, getattr(result, attribute))
diff --git a/google/cloud/aiplatform/constants/base.py b/google/cloud/aiplatform/constants/base.py
index 8c1bc1b613..9af5e339bf 100644
--- a/google/cloud/aiplatform/constants/base.py
+++ b/google/cloud/aiplatform/constants/base.py
@@ -92,3 +92,6 @@
 # that is being used for usage metrics tracking purposes.
 # For more details on go/oneplatform-api-analytics
 USER_AGENT_SDK_COMMAND = ""
+
+# Needed for Endpoint.raw_predict
+DEFAULT_AUTHED_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index 6b50d06f70..54a7589429 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -20,6 +20,7 @@
 import re
 import shutil
 import tempfile
+import requests
 from typing import (
     Any,
     Dict,
@@ -35,9 +36,11 @@
 from google.api_core import operation
 from google.api_core import exceptions as api_exceptions
 from google.auth import credentials as auth_credentials
+from google.auth.transport import requests as google_auth_requests
 
 from google.cloud import aiplatform
 from google.cloud.aiplatform import base
+from google.cloud.aiplatform import constants
 from google.cloud.aiplatform import explain
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform import jobs
@@ -69,6 +72,8 @@
 _DEFAULT_MACHINE_TYPE = "n1-standard-2"
 _DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
 _SUCCESSFUL_HTTP_RESPONSE = 300
+_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
+_RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"
 
 _LOGGER = base.Logger(__name__)
 
@@ -200,6 +205,8 @@ def __init__(
             location=self.location,
             credentials=credentials,
         )
+        self._authorized_session = None
+        self._raw_predict_request_url = None
 
     def _skipped_getter_call(self) -> bool:
         """Check if GAPIC resource was populated by call to get/list API methods
@@ -1389,16 +1396,15 @@ def update(
         """Updates an endpoint.
 
         Example usage:
-
-        my_endpoint = my_endpoint.update(
-            display_name='my-updated-endpoint',
-            description='my updated description',
-            labels={'key': 'value'},
-            traffic_split={
-                '123456': 20,
-                '234567': 80,
-            },
-        )
+            my_endpoint = my_endpoint.update(
+                display_name='my-updated-endpoint',
+                description='my updated description',
+                labels={'key': 'value'},
+                traffic_split={
+                    '123456': 20,
+                    '234567': 80,
+                },
+            )
 
         Args:
             display_name (str):
@@ -1481,6 +1487,7 @@ def predict(
         instances: List,
         parameters: Optional[Dict] = None,
         timeout: Optional[float] = None,
+        use_raw_predict: Optional[bool] = False,
     ) -> Prediction:
         """Make a prediction against this Endpoint.
 
@@ -1505,29 +1512,80 @@ def predict(
                 [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
                 ``parameters_schema_uri``.
             timeout (float): Optional. The timeout for this request in seconds.
+            use_raw_predict (bool):
+                Optional. Default value is False. If set to True, the underlying prediction call will be made
+                via Endpoint.raw_predict(). Note that model version information will
+                not be available in the prediction response when using raw_predict.
 
         Returns:
             prediction (aiplatform.Prediction):
                 Prediction with returned predictions and Model ID.
         """
         self.wait()
+        if use_raw_predict:
+            raw_predict_response = self.raw_predict(
+                body=json.dumps({"instances": instances, "parameters": parameters}),
+                headers={"Content-Type": "application/json"},
+            )
+            json_response = json.loads(raw_predict_response.text)
+            return Prediction(
+                predictions=json_response["predictions"],
+                deployed_model_id=raw_predict_response.headers[
+                    _RAW_PREDICT_DEPLOYED_MODEL_ID_KEY
+                ],
+                model_resource_name=raw_predict_response.headers[
+                    _RAW_PREDICT_MODEL_RESOURCE_KEY
+                ],
+            )
+        else:
+            prediction_response = self._prediction_client.predict(
+                endpoint=self._gca_resource.name,
+                instances=instances,
+                parameters=parameters,
+                timeout=timeout,
+            )
 
-        prediction_response = self._prediction_client.predict(
-            endpoint=self._gca_resource.name,
-            instances=instances,
-            parameters=parameters,
-            timeout=timeout,
-        )
+            return Prediction(
+                predictions=[
+                    json_format.MessageToDict(item)
+                    for item in prediction_response.predictions.pb
+                ],
+                deployed_model_id=prediction_response.deployed_model_id,
+                model_version_id=prediction_response.model_version_id,
+                model_resource_name=prediction_response.model,
+            )
 
-        return Prediction(
-            predictions=[
-                json_format.MessageToDict(item)
-                for item in prediction_response.predictions.pb
-            ],
-            deployed_model_id=prediction_response.deployed_model_id,
-            model_version_id=prediction_response.model_version_id,
-            model_resource_name=prediction_response.model,
-        )
+    def raw_predict(
+        self, body: bytes, headers: Dict[str, str]
+    ) -> requests.models.Response:
+        """Makes a prediction request using arbitrary headers.
+
+        Example usage:
+            my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
+            response = my_endpoint.raw_predict(
+                body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
+                headers = {'Content-Type':'application/json'}
+            )
+            status_code = response.status_code
+            results = json.loads(response.text)
+
+        Args:
+            body (bytes):
+                The body of the prediction request in bytes. This must not exceed 1.5 MB per request.
+            headers (Dict[str, str]):
+                The headers of the request as a dictionary. There are no restrictions on the headers.
+
+        Returns:
+            A requests.models.Response object containing the status code and prediction results.
+        """
+        if not self._authorized_session:
+            self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
+            self._authorized_session = google_auth_requests.AuthorizedSession(
+                self.credentials
+            )
+            self._raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"
+
+        return self._authorized_session.post(self._raw_predict_request_url, body, headers)
 
     def explain(
         self,
@@ -2004,7 +2062,7 @@ def _http_request(
     def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
         """Make a prediction against this PrivateEndpoint using a HTTP request.
         This method must be called within the network the PrivateEndpoint is peered to.
-        The predict() call will fail otherwise. To check, use `PrivateEndpoint.network`.
+        Otherwise, the predict() call will fail with error code 404. To check, use `PrivateEndpoint.network`.
 
         Example usage:
             response = my_private_endpoint.predict(instances=[...])
@@ -2062,6 +2120,39 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
             deployed_model_id=self._gca_resource.deployed_models[0].id,
         )
 
+    def raw_predict(
+        self, body: bytes, headers: Dict[str, str]
+    ) -> requests.models.Response:
+        """Makes a prediction request using arbitrary headers.
+        This method must be called within the network the PrivateEndpoint is peered to.
+        Otherwise, the raw_predict() call will fail with error code 404. To check, use `PrivateEndpoint.network`.
+
+        Example usage:
+            my_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID)
+            response = my_endpoint.raw_predict(
+                body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
+                headers = {'Content-Type':'application/json'}
+            )
+            status_code = response.status_code
+            results = json.loads(response.text)
+
+        Args:
+            body (bytes):
+                The body of the prediction request in bytes. This must not exceed 1.5 MB per request.
+            headers (Dict[str, str]):
+                The headers of the request as a dictionary. There are no restrictions on the headers.
+
+        Returns:
+            A requests.models.Response object containing the status code and prediction results.
+        """
+        self.wait()
+        return self._http_request(
+            method="POST",
+            url=self.predict_http_uri,
+            body=body,
+            headers=headers,
+        )
+
     def explain(self):
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not support 'explain' as of now."
diff --git a/setup.py b/setup.py
index 534cd91932..5e8165c12b 100644
--- a/setup.py
+++ b/setup.py
@@ -80,9 +80,9 @@
     "uvicorn >= 0.16.0",
 ]
 
-private_endpoints_extra_require = [
-    "urllib3 >=1.21.1, <1.27",
-]
+endpoint_extra_require = ["requests >= 2.28.1"]
+
+private_endpoints_extra_require = ["urllib3 >=1.21.1, <1.27", "requests >= 2.28.1"]
 full_extra_require = list(
     set(
         tensorboard_extra_require
@@ -92,6 +92,7 @@
     + featurestore_extra_require
     + pipelines_extra_require
     + datasets_extra_require
+    + endpoint_extra_require
     + vizier_extra_require
     + prediction_extra_require
     + private_endpoints_extra_require
@@ -136,6 +137,7 @@
         "google-cloud-resource-manager >= 1.3.3, < 3.0.0dev",
     ),
     extras_require={
+        "endpoint": endpoint_extra_require,
         "full": full_extra_require,
         "metadata": metadata_extra_require,
         "tensorboard": tensorboard_extra_require,
diff --git a/tests/system/aiplatform/test_model_interactions.py b/tests/system/aiplatform/test_model_interactions.py
new file mode 100644
index 0000000000..ff9a28139e
--- /dev/null
+++ b/tests/system/aiplatform/test_model_interactions.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+
+from google.cloud import aiplatform
+
+from tests.system.aiplatform import e2e_base
+
+_PERMANENT_IRIS_ENDPOINT_ID = "4966625964059525120"
+_PREDICTION_INSTANCE = {
+    "petal_length": "3.0",
+    "petal_width": "3.0",
+    "sepal_length": "3.0",
+    "sepal_width": "3.0",
+}
+
+
+class TestModelInteractions(e2e_base.TestEndToEnd):
+    _temp_prefix = ""
+    endpoint = aiplatform.Endpoint(_PERMANENT_IRIS_ENDPOINT_ID)
+
+    def test_prediction(self):
+        # test basic predict
+        prediction_response = self.endpoint.predict(instances=[_PREDICTION_INSTANCE])
+        assert len(prediction_response.predictions) == 1
+
+        # test predict(use_raw_predict = True)
+        prediction_with_raw_predict = self.endpoint.predict(
+            instances=[_PREDICTION_INSTANCE], use_raw_predict=True
+        )
+        assert (
+            prediction_with_raw_predict.deployed_model_id
+            == prediction_response.deployed_model_id
+        )
+        assert (
+            prediction_with_raw_predict.model_resource_name
+            == prediction_response.model_resource_name
+        )
+
+        # test raw_predict
+        raw_prediction_response = self.endpoint.raw_predict(
+            json.dumps({"instances": [_PREDICTION_INSTANCE]}),
+            {"Content-Type": "application/json"},
+        )
+        assert raw_prediction_response.status_code == 200
+        assert len(json.loads(raw_prediction_response.text)["predictions"]) == 1
diff --git a/tests/system/aiplatform/test_model_upload.py b/tests/system/aiplatform/test_model_upload.py
index 48e6169af4..b019982a72 100644
--- a/tests/system/aiplatform/test_model_upload.py
+++ b/tests/system/aiplatform/test_model_upload.py
@@ -29,7 +29,7 @@
 
 
 @pytest.mark.usefixtures("delete_staging_bucket", "tear_down_resources")
-class TestModel(e2e_base.TestEndToEnd):
+class TestModelUploadAndUpdate(e2e_base.TestEndToEnd):
 
     _temp_prefix = "temp_vertex_sdk_e2e_model_upload_test"
 
@@ -65,9 +65,8 @@ def test_upload_and_deploy_xgboost_model(self, shared_state):
         # See https://github.com/googleapis/python-aiplatform/issues/773
         endpoint = model.deploy(machine_type="n1-standard-2")
         shared_state["resources"].append(endpoint)
-        predict_response = endpoint.predict(instances=[[0, 0, 0]])
-        assert len(predict_response.predictions) == 1
 
+        # test model update
         model = model.update(
             display_name="new_name",
             description="new_description",
diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py
index 71d7b538a2..8521854c34 100644
--- a/tests/unit/aiplatform/test_models.py
+++ b/tests/unit/aiplatform/test_models.py
@@ -19,6 +19,7 @@
 from concurrent import futures
 import pathlib
 import pytest
+import requests
 from unittest import mock
 from unittest.mock import patch
 
@@ -31,6 +32,7 @@
 from google.cloud.aiplatform import initializer
 from google.cloud.aiplatform import models
 from google.cloud.aiplatform import utils
+from google.cloud.aiplatform import constants
 
 from google.cloud.aiplatform.compat.services import (
     endpoint_service_client,
@@ -309,6 +311,10 @@
 
 _TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_ID}"
 
+_TEST_RAW_PREDICT_URL = f"https://{_TEST_LOCATION}-{constants.base.API_BASE_PATH}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict"
+_TEST_RAW_PREDICT_DATA = b""
+_TEST_RAW_PREDICT_HEADER = {"Content-Type": "application/json"}
+
 
 @pytest.fixture
 def mock_model():
@@ -329,6 +335,22 @@
 def update_model_mock(mock_model):
         yield mock
 
 
+@pytest.fixture
+def authorized_session_mock():
+    with patch(
+        "google.auth.transport.requests.AuthorizedSession"
+    ) as MockAuthorizedSession:
+        mock_auth_session = MockAuthorizedSession(_TEST_CREDENTIALS)
+        yield mock_auth_session
+
+
+@pytest.fixture
+def raw_predict_mock(authorized_session_mock):
+    with patch.object(authorized_session_mock, "post") as mock_post:
+        mock_post.return_value = requests.models.Response()
+        yield mock_post
+
+
 @pytest.fixture
 def get_endpoint_mock():
     with mock.patch.object(
@@ -2707,3 +2729,16 @@ def test_list(self, list_models_mock):
 
         assert listed_model.versioning_registry
         assert listed_model._revisioned_resource_id_validator
+
+    @pytest.mark.usefixtures(
+        "get_endpoint_mock",
+        "get_model_mock",
+        "create_endpoint_mock",
+        "raw_predict_mock",
+    )
+    def test_raw_predict(self, raw_predict_mock):
+        test_endpoint = models.Endpoint(_TEST_ID)
+        test_endpoint.raw_predict(_TEST_RAW_PREDICT_DATA, _TEST_RAW_PREDICT_HEADER)
+        raw_predict_mock.assert_called_once_with(
+            _TEST_RAW_PREDICT_URL, _TEST_RAW_PREDICT_DATA, _TEST_RAW_PREDICT_HEADER
+        )
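
A minimal usage sketch of the new prediction surface, assuming an existing endpoint with a deployed model; the project, location, endpoint ID, and feature payload below are placeholders, not values from this change:

```python
import json

from google.cloud import aiplatform

# Placeholder project/location/endpoint ID; substitute real values.
aiplatform.init(project="my-project", location="us-central1")
endpoint = aiplatform.Endpoint("1234567890")

# predict() routed through the rawPredict REST path. Note that
# model_version_id is not populated on this path, as documented above.
prediction = endpoint.predict(
    instances=[{"feat_1": 1.0, "feat_2": 2.0}],
    use_raw_predict=True,
)
print(prediction.predictions, prediction.deployed_model_id)

# raw_predict() directly, with a caller-controlled body and headers.
response = endpoint.raw_predict(
    body=json.dumps({"instances": [{"feat_1": 1.0, "feat_2": 2.0}]}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
print(response.status_code)
print(json.loads(response.text)["predictions"])

# The serving deployed model ID is echoed back in a response header.
print(response.headers["X-Vertex-AI-Deployed-Model-Id"])
```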