Skip to content

Commit

Permalink
feat: support raw_predict for Endpoint (#1620)
Browse files Browse the repository at this point in the history
* feat: support raw_predict for Endpoints

* formatting

* fixed broken unit test

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* remove commented out code blocks

* removing debug print statements

* removing extra prints

* update copyright header date

* removing automatically added python 3.6 support for kokoro

* addressed PR comments

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* adding unit test

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* removed unused import

* renamed raw predict constants

* modified error messages

* added doc strings

* fixed typo in doc strings

* removed extra space

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
Co-authored-by: nayaknishant <nishantnayak@google.com>
  • Loading branch information
3 people authored Sep 2, 2022
1 parent 39e3be9 commit cc7c968
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 34 deletions.
8 changes: 6 additions & 2 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _are_futures_done(self) -> bool:
return self.__latest_future is None

def wait(self):
"""Helper method to that blocks until all futures are complete."""
"""Helper method that blocks until all futures are complete."""
future = self.__latest_future
if future:
futures.wait([future], return_when=futures.FIRST_EXCEPTION)
Expand Down Expand Up @@ -974,7 +974,11 @@ def _sync_object_with_future_result(
"_gca_resource",
"credentials",
]
optional_sync_attributes = ["_prediction_client"]
optional_sync_attributes = [
"_prediction_client",
"_authorized_session",
"_raw_predict_request_url",
]

for attribute in sync_attributes:
setattr(self, attribute, getattr(result, attribute))
Expand Down
3 changes: 3 additions & 0 deletions google/cloud/aiplatform/constants/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@
# that is being used for usage metrics tracking purposes.
# For more details on go/oneplatform-api-analytics
USER_AGENT_SDK_COMMAND = ""

# Needed for Endpoint.raw_predict
DEFAULT_AUTHED_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
143 changes: 117 additions & 26 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import re
import shutil
import tempfile
import requests
from typing import (
Any,
Dict,
Expand All @@ -35,9 +36,11 @@
from google.api_core import operation
from google.api_core import exceptions as api_exceptions
from google.auth import credentials as auth_credentials
from google.auth.transport import requests as google_auth_requests

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import constants
from google.cloud.aiplatform import explain
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import jobs
Expand Down Expand Up @@ -69,6 +72,8 @@
_DEFAULT_MACHINE_TYPE = "n1-standard-2"
_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
_SUCCESSFUL_HTTP_RESPONSE = 300
_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
_RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"

_LOGGER = base.Logger(__name__)

Expand Down Expand Up @@ -200,6 +205,8 @@ def __init__(
location=self.location,
credentials=credentials,
)
self.authorized_session = None
self.raw_predict_request_url = None

def _skipped_getter_call(self) -> bool:
"""Check if GAPIC resource was populated by call to get/list API methods
Expand Down Expand Up @@ -1389,16 +1396,15 @@ def update(
"""Updates an endpoint.
Example usage:
my_endpoint = my_endpoint.update(
display_name='my-updated-endpoint',
description='my updated description',
labels={'key': 'value'},
traffic_split={
'123456': 20,
'234567': 80,
},
)
my_endpoint = my_endpoint.update(
display_name='my-updated-endpoint',
description='my updated description',
labels={'key': 'value'},
traffic_split={
'123456': 20,
'234567': 80,
},
)
Args:
display_name (str):
Expand Down Expand Up @@ -1481,6 +1487,7 @@ def predict(
instances: List,
parameters: Optional[Dict] = None,
timeout: Optional[float] = None,
use_raw_predict: Optional[bool] = False,
) -> Prediction:
"""Make a prediction against this Endpoint.
Expand All @@ -1505,29 +1512,80 @@ def predict(
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
``parameters_schema_uri``.
timeout (float): Optional. The timeout for this request in seconds.
use_raw_predict (bool):
Optional. Default value is False. If set to True, the underlying prediction call will be made
against Endpoint.raw_predict(). Note that model version information will
not be available in the prediction response when using raw_predict.
Returns:
prediction (aiplatform.Prediction):
Prediction with returned predictions and Model ID.
"""
self.wait()
if use_raw_predict:
raw_predict_response = self.raw_predict(
body=json.dumps({"instances": instances, "parameters": parameters}),
headers={"Content-Type": "application/json"},
)
json_response = json.loads(raw_predict_response.text)
return Prediction(
predictions=json_response["predictions"],
deployed_model_id=raw_predict_response.headers[
_RAW_PREDICT_DEPLOYED_MODEL_ID_KEY
],
model_resource_name=raw_predict_response.headers[
_RAW_PREDICT_MODEL_RESOURCE_KEY
],
)
else:
prediction_response = self._prediction_client.predict(
endpoint=self._gca_resource.name,
instances=instances,
parameters=parameters,
timeout=timeout,
)

prediction_response = self._prediction_client.predict(
endpoint=self._gca_resource.name,
instances=instances,
parameters=parameters,
timeout=timeout,
)
return Prediction(
predictions=[
json_format.MessageToDict(item)
for item in prediction_response.predictions.pb
],
deployed_model_id=prediction_response.deployed_model_id,
model_version_id=prediction_response.model_version_id,
model_resource_name=prediction_response.model,
)

return Prediction(
predictions=[
json_format.MessageToDict(item)
for item in prediction_response.predictions.pb
],
deployed_model_id=prediction_response.deployed_model_id,
model_version_id=prediction_response.model_version_id,
model_resource_name=prediction_response.model,
)
def raw_predict(
    self, body: bytes, headers: Dict[str, str]
) -> requests.models.Response:
    """Makes a prediction request against this Endpoint using arbitrary headers.

    Example usage:
        my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
        response = my_endpoint.raw_predict(
            body=b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
            headers={'Content-Type': 'application/json'},
        )
        status_code = response.status_code
        results = json.loads(response.text)

    Args:
        body (bytes):
            The body of the prediction request in bytes. This must not
            exceed 1.5 MB per request.
        headers (Dict[str, str]):
            The header of the request as a dictionary. There are no
            restrictions on the header.

    Returns:
        A requests.models.Response object containing the status code and
        prediction results.
    """
    if not self.authorized_session:
        # NOTE(review): this mutates the private `_scopes` attribute of the
        # credentials object so the session is authorized for the
        # cloud-platform scope — confirm against google-auth that there is
        # no public API for this.
        self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
        self.authorized_session = google_auth_requests.AuthorizedSession(
            self.credentials
        )
        # Cache the rawPredict URL; it is derived once from the resource
        # identity and reused on subsequent calls.
        self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"

    # Bug fix: pass arguments by keyword. requests.Session.post() is
    # post(url, data=None, json=None, **kwargs), so the previous positional
    # call post(url, body, headers) bound `headers` to the `json` parameter
    # and the caller's headers (e.g. Content-Type) were silently dropped.
    return self.authorized_session.post(
        url=self.raw_predict_request_url, data=body, headers=headers
    )

def explain(
self,
Expand Down Expand Up @@ -2004,7 +2062,7 @@ def _http_request(
def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
"""Make a prediction against this PrivateEndpoint using a HTTP request.
This method must be called within the network the PrivateEndpoint is peered to.
The predict() call will fail otherwise. To check, use `PrivateEndpoint.network`.
Otherwise, the predict() call will fail with error code 404. To check, use `PrivateEndpoint.network`.
Example usage:
response = my_private_endpoint.predict(instances=[...])
Expand Down Expand Up @@ -2062,6 +2120,39 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
deployed_model_id=self._gca_resource.deployed_models[0].id,
)

def raw_predict(
    self, body: bytes, headers: Dict[str, str]
) -> requests.models.Response:
    """Make a prediction request against this PrivateEndpoint using arbitrary headers.

    This method must be called within the network the PrivateEndpoint is
    peered to; otherwise the raw_predict() call will fail with error code
    404. To check, use `PrivateEndpoint.network`.

    Example usage:
        my_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID)
        response = my_endpoint.raw_predict(
            body=b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
            headers={'Content-Type': 'application/json'},
        )
        status_code = response.status_code
        results = json.loads(response.text)

    Args:
        body (bytes):
            The body of the prediction request in bytes. This must not
            exceed 1.5 MB per request.
        headers (Dict[str, str]):
            The header of the request as a dictionary. There are no
            restrictions on the header.

    Returns:
        A requests.models.Response object containing the status code and
        prediction results.
    """
    # Block until any pending operations on this resource have completed.
    self.wait()
    target_url = self.predict_http_uri
    response = self._http_request(
        method="POST",
        url=target_url,
        body=body,
        headers=headers,
    )
    return response

def explain(self):
raise NotImplementedError(
f"{self.__class__.__name__} class does not support 'explain' as of now."
Expand Down
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@
"uvicorn >= 0.16.0",
]

private_endpoints_extra_require = [
"urllib3 >=1.21.1, <1.27",
]
endpoint_extra_require = ["requests >= 2.28.1"]

private_endpoints_extra_require = ["urllib3 >=1.21.1, <1.27", "requests >= 2.28.1"]
full_extra_require = list(
set(
tensorboard_extra_require
Expand All @@ -92,6 +92,7 @@
+ featurestore_extra_require
+ pipelines_extra_require
+ datasets_extra_require
+ endpoint_extra_require
+ vizier_extra_require
+ prediction_extra_require
+ private_endpoints_extra_require
Expand Down Expand Up @@ -136,6 +137,7 @@
"google-cloud-resource-manager >= 1.3.3, < 3.0.0dev",
),
extras_require={
"endpoint": endpoint_extra_require,
"full": full_extra_require,
"metadata": metadata_extra_require,
"tensorboard": tensorboard_extra_require,
Expand Down
61 changes: 61 additions & 0 deletions tests/system/aiplatform/test_model_interactions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json

from google.cloud import aiplatform

from tests.system.aiplatform import e2e_base

_PERMANENT_IRIS_ENDPOINT_ID = "4966625964059525120"
_PREDICTION_INSTANCE = {
"petal_length": "3.0",
"petal_width": "3.0",
"sepal_length": "3.0",
"sepal_width": "3.0",
}


class TestModelInteractions(e2e_base.TestEndToEnd):
    """System tests for Endpoint.predict() and Endpoint.raw_predict()."""

    _temp_prefix = ""
    # Permanent shared endpoint; no resources are created or torn down here.
    endpoint = aiplatform.Endpoint(_PERMANENT_IRIS_ENDPOINT_ID)

    def test_prediction(self):
        # Baseline: the regular predict() path.
        baseline = self.endpoint.predict(instances=[_PREDICTION_INSTANCE])
        assert len(baseline.predictions) == 1

        # predict(use_raw_predict=True) must agree with the baseline on
        # the deployed model and the model resource name.
        via_raw = self.endpoint.predict(
            instances=[_PREDICTION_INSTANCE], use_raw_predict=True
        )
        assert via_raw.deployed_model_id == baseline.deployed_model_id
        assert via_raw.model_resource_name == baseline.model_resource_name

        # Direct raw_predict() with a JSON body and explicit headers.
        http_response = self.endpoint.raw_predict(
            json.dumps({"instances": [_PREDICTION_INSTANCE]}),
            {"Content-Type": "application/json"},
        )
        assert http_response.status_code == 200
        assert len(json.loads(http_response.text)) == 1
5 changes: 2 additions & 3 deletions tests/system/aiplatform/test_model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


@pytest.mark.usefixtures("delete_staging_bucket", "tear_down_resources")
class TestModel(e2e_base.TestEndToEnd):
class TestModelUploadAndUpdate(e2e_base.TestEndToEnd):

_temp_prefix = "temp_vertex_sdk_e2e_model_upload_test"

Expand Down Expand Up @@ -65,9 +65,8 @@ def test_upload_and_deploy_xgboost_model(self, shared_state):
# See https://github.com/googleapis/python-aiplatform/issues/773
endpoint = model.deploy(machine_type="n1-standard-2")
shared_state["resources"].append(endpoint)
predict_response = endpoint.predict(instances=[[0, 0, 0]])
assert len(predict_response.predictions) == 1

# test model update
model = model.update(
display_name="new_name",
description="new_description",
Expand Down
Loading

0 comments on commit cc7c968

Please sign in to comment.