Commit b9a057d
fix: Fix create_lit_model_from_endpoint not accepting models that don't return a dictionary. (#1020)

Some models, like Keras sequential models, don't return a dictionary for their predictions. We need to support these models because they are commonly used.

Fixes b/220167889
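
For illustration, the two response shapes at issue (mirrored by the new test constants in the diff below) are:

    # Models exporting named outputs return one mapping per instance:
    dict_prediction = {"label": 1.0}
    # A Keras sequential model typically returns a bare list of values,
    # ordered to match the configured output types:
    list_prediction = [1.0]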
taiseiak authored Mar 2, 2022
1 parent e7d2719 commit b9a057d
Showing 2 changed files with 153 additions and 15 deletions.
9 changes: 7 additions & 2 deletions google/cloud/aiplatform/explain/lit.py
@@ -18,7 +18,7 @@
 import os
 
 from google.cloud import aiplatform
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Mapping, Optional, Tuple, Union
 
 try:
     from lit_nlp.api import dataset as lit_dataset
@@ -154,7 +154,12 @@ def predict_minibatch(
         prediction_object = self._endpoint.predict(instances)
         outputs = []
         for prediction in prediction_object.predictions:
-            outputs.append({key: prediction[key] for key in self._output_types})
+            if isinstance(prediction, Mapping):
+                outputs.append({key: prediction[key] for key in self._output_types})
+            else:
+                outputs.append(
+                    {key: prediction[i] for i, key in enumerate(self._output_types)}
+                )
         if self._explanation_enabled:
             for i, explanation in enumerate(prediction_object.explanations):
                 attributions = explanation.attributions
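
The change above dispatches on the prediction's type. A minimal runnable sketch of the same logic; the helper name and `output_types` parameter are illustrative, not part of the module's API:

    from typing import Mapping

    def row_to_output(prediction, output_types):
        """Map one prediction row to {output_name: value}."""
        if isinstance(prediction, Mapping):
            # Dict-like predictions are keyed by output name.
            return {key: prediction[key] for key in output_types}
        # Sequence predictions (e.g. from a Keras sequential model) are
        # matched to output names by position.
        return {key: prediction[i] for i, key in enumerate(output_types)}

    assert row_to_output({"label": 1.0}, ["label"]) == {"label": 1.0}
    assert row_to_output([1.0], ["label"]) == {"label": 1.0}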
159 changes: 146 additions & 13 deletions tests/unit/aiplatform/test_explain_lit.py
@@ -105,7 +105,8 @@
     ),
 ]
 _TEST_TRAFFIC_SPLIT = {_TEST_ID: 0, _TEST_ID_2: 100, _TEST_ID_3: 0}
-_TEST_PREDICTION = [{"label": 1.0}]
+_TEST_DICT_PREDICTION = [{"label": 1.0}]
+_TEST_LIST_PREDICTION = [[1.0]]
 _TEST_EXPLANATIONS = [gca_prediction_service.explanation.Explanation(attributions=[])]
 _TEST_ATTRIBUTIONS = [
     gca_prediction_service.explanation.Attribution(
@@ -218,26 +219,54 @@ def get_endpoint_with_models_with_explanation_mock():
 
 
 @pytest.fixture
-def predict_client_predict_mock():
+def predict_client_predict_dict_mock():
     with mock.patch.object(
         prediction_service_client.PredictionServiceClient, "predict"
     ) as predict_mock:
         predict_mock.return_value = gca_prediction_service.PredictResponse(
             deployed_model_id=_TEST_ID
         )
-        predict_mock.return_value.predictions.extend(_TEST_PREDICTION)
+        predict_mock.return_value.predictions.extend(_TEST_DICT_PREDICTION)
         yield predict_mock
 
 
 @pytest.fixture
-def predict_client_explain_mock():
+def predict_client_explain_dict_mock():
     with mock.patch.object(
         prediction_service_client.PredictionServiceClient, "explain"
     ) as predict_mock:
         predict_mock.return_value = gca_prediction_service.ExplainResponse(
             deployed_model_id=_TEST_ID,
         )
-        predict_mock.return_value.predictions.extend(_TEST_PREDICTION)
+        predict_mock.return_value.predictions.extend(_TEST_DICT_PREDICTION)
         predict_mock.return_value.explanations.extend(_TEST_EXPLANATIONS)
         predict_mock.return_value.explanations[0].attributions.extend(
             _TEST_ATTRIBUTIONS
         )
         yield predict_mock
+
+
+@pytest.fixture
+def predict_client_predict_list_mock():
+    with mock.patch.object(
+        prediction_service_client.PredictionServiceClient, "predict"
+    ) as predict_mock:
+        predict_mock.return_value = gca_prediction_service.PredictResponse(
+            deployed_model_id=_TEST_ID
+        )
+        predict_mock.return_value.predictions.extend(_TEST_LIST_PREDICTION)
+        yield predict_mock
+
+
+@pytest.fixture
+def predict_client_explain_list_mock():
+    with mock.patch.object(
+        prediction_service_client.PredictionServiceClient, "explain"
+    ) as predict_mock:
+        predict_mock.return_value = gca_prediction_service.ExplainResponse(
+            deployed_model_id=_TEST_ID,
+        )
+        predict_mock.return_value.predictions.extend(_TEST_LIST_PREDICTION)
+        predict_mock.return_value.explanations.extend(_TEST_EXPLANATIONS)
+        predict_mock.return_value.explanations[0].attributions.extend(
+            _TEST_ATTRIBUTIONS
@@ -312,10 +341,112 @@ def test_create_lit_model_from_tensorflow_with_xai_returns_model(
         assert len(item.values()) == 2
 
     @pytest.mark.usefixtures(
-        "predict_client_predict_mock", "get_endpoint_with_models_mock"
+        "predict_client_predict_dict_mock", "get_endpoint_with_models_mock"
     )
     @pytest.mark.parametrize("model_id", [None, _TEST_ID])
+    def test_create_lit_model_from_dict_endpoint_returns_model(
+        self, feature_types, label_types, model_id
+    ):
+        endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)
+        lit_model = create_lit_model_from_endpoint(
+            endpoint, feature_types, label_types, model_id
+        )
+        test_inputs = [
+            {"feature_1": 1.0, "feature_2": 2.0},
+        ]
+        outputs = lit_model.predict_minibatch(test_inputs)
+
+        assert lit_model.input_spec() == dict(feature_types)
+        assert lit_model.output_spec() == dict(label_types)
+        assert len(outputs) == 1
+        for item in outputs:
+            assert item.keys() == {"label"}
+            assert len(item.values()) == 1
+
+    @pytest.mark.usefixtures(
+        "predict_client_explain_dict_mock",
+        "get_endpoint_with_models_with_explanation_mock",
+    )
+    @pytest.mark.parametrize("model_id", [None, _TEST_ID])
+    def test_create_lit_model_from_dict_endpoint_with_xai_returns_model(
+        self, feature_types, label_types, model_id
+    ):
+        endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)
+        lit_model = create_lit_model_from_endpoint(
+            endpoint, feature_types, label_types, model_id
+        )
+        test_inputs = [
+            {"feature_1": 1.0, "feature_2": 2.0},
+        ]
+        outputs = lit_model.predict_minibatch(test_inputs)
+
+        assert lit_model.input_spec() == dict(feature_types)
+        assert lit_model.output_spec() == dict(
+            {
+                **label_types,
+                "feature_attribution": lit_types.FeatureSalience(signed=True),
+            }
+        )
+        assert len(outputs) == 1
+        for item in outputs:
+            assert item.keys() == {"label", "feature_attribution"}
+            assert len(item.values()) == 2
+
+    @pytest.mark.usefixtures(
+        "predict_client_predict_dict_mock", "get_endpoint_with_models_mock"
+    )
+    @pytest.mark.parametrize("model_id", [None, _TEST_ID])
+    def test_create_lit_model_from_dict_endpoint_name_returns_model(
+        self, feature_types, label_types, model_id
+    ):
+        lit_model = create_lit_model_from_endpoint(
+            _TEST_ENDPOINT_NAME, feature_types, label_types, model_id
+        )
+        test_inputs = [
+            {"feature_1": 1.0, "feature_2": 2.0},
+        ]
+        outputs = lit_model.predict_minibatch(test_inputs)
+
+        assert lit_model.input_spec() == dict(feature_types)
+        assert lit_model.output_spec() == dict(label_types)
+        assert len(outputs) == 1
+        for item in outputs:
+            assert item.keys() == {"label"}
+            assert len(item.values()) == 1
+
+    @pytest.mark.usefixtures(
+        "predict_client_explain_dict_mock",
+        "get_endpoint_with_models_with_explanation_mock",
+    )
+    @pytest.mark.parametrize("model_id", [None, _TEST_ID])
+    def test_create_lit_model_from_dict_endpoint_name_with_xai_returns_model(
+        self, feature_types, label_types, model_id
+    ):
+        lit_model = create_lit_model_from_endpoint(
+            _TEST_ENDPOINT_NAME, feature_types, label_types, model_id
+        )
+        test_inputs = [
+            {"feature_1": 1.0, "feature_2": 2.0},
+        ]
+        outputs = lit_model.predict_minibatch(test_inputs)
+
+        assert lit_model.input_spec() == dict(feature_types)
+        assert lit_model.output_spec() == dict(
+            {
+                **label_types,
+                "feature_attribution": lit_types.FeatureSalience(signed=True),
+            }
+        )
+        assert len(outputs) == 1
+        for item in outputs:
+            assert item.keys() == {"label", "feature_attribution"}
+            assert len(item.values()) == 2
+
+    @pytest.mark.usefixtures(
+        "predict_client_predict_list_mock", "get_endpoint_with_models_mock"
+    )
+    @pytest.mark.parametrize("model_id", [None, _TEST_ID])
-    def test_create_lit_model_from_endpoint_returns_model(
+    def test_create_lit_model_from_list_endpoint_returns_model(
         self, feature_types, label_types, model_id
     ):
         endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)
@@ -335,10 +466,11 @@ def test_create_lit_model_from_endpoint_returns_model(
         assert len(item.values()) == 1
 
     @pytest.mark.usefixtures(
-        "predict_client_explain_mock", "get_endpoint_with_models_with_explanation_mock"
+        "predict_client_explain_list_mock",
+        "get_endpoint_with_models_with_explanation_mock",
     )
     @pytest.mark.parametrize("model_id", [None, _TEST_ID])
-    def test_create_lit_model_from_endpoint_with_xai_returns_model(
+    def test_create_lit_model_from_list_endpoint_with_xai_returns_model(
         self, feature_types, label_types, model_id
     ):
         endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME)
@@ -363,10 +495,10 @@ def test_create_lit_model_from_endpoint_with_xai_returns_model(
         assert len(item.values()) == 2
 
     @pytest.mark.usefixtures(
-        "predict_client_predict_mock", "get_endpoint_with_models_mock"
+        "predict_client_predict_list_mock", "get_endpoint_with_models_mock"
     )
     @pytest.mark.parametrize("model_id", [None, _TEST_ID])
-    def test_create_lit_model_from_endpoint_name_returns_model(
+    def test_create_lit_model_from_list_endpoint_name_returns_model(
         self, feature_types, label_types, model_id
     ):
         lit_model = create_lit_model_from_endpoint(
@@ -385,10 +517,11 @@ def test_create_lit_model_from_endpoint_name_returns_model(
         assert len(item.values()) == 1
 
     @pytest.mark.usefixtures(
-        "predict_client_explain_mock", "get_endpoint_with_models_with_explanation_mock"
+        "predict_client_explain_list_mock",
+        "get_endpoint_with_models_with_explanation_mock",
    )
     @pytest.mark.parametrize("model_id", [None, _TEST_ID])
-    def test_create_lit_model_from_endpoint_name_with_xai_returns_model(
+    def test_create_lit_model_from_list_endpoint_name_with_xai_returns_model(
         self, feature_types, label_types, model_id
     ):
         lit_model = create_lit_model_from_endpoint(
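For context, a sketch of the call path these tests exercise; the endpoint resource name and the LIT field types are illustrative placeholders, not values taken from this commit:

    from google.cloud import aiplatform
    from google.cloud.aiplatform.explain.lit import create_lit_model_from_endpoint
    from lit_nlp.api import types as lit_types

    feature_types = {"feature_1": lit_types.Scalar(), "feature_2": lit_types.Scalar()}
    label_types = {"label": lit_types.RegressionScore()}

    # Placeholder endpoint resource name.
    endpoint = aiplatform.Endpoint(
        "projects/my-project/locations/us-central1/endpoints/1234567890"
    )
    # After this fix, the wrapper accepts models that return either
    # {"label": 1.0} or [1.0] per instance.
    lit_model = create_lit_model_from_endpoint(endpoint, feature_types, label_types)
    outputs = lit_model.predict_minibatch([{"feature_1": 1.0, "feature_2": 2.0}])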
