diff --git a/google/cloud/aiplatform/explain/lit.py b/google/cloud/aiplatform/explain/lit.py index 5032055801..635ebb1ce8 100644 --- a/google/cloud/aiplatform/explain/lit.py +++ b/google/cloud/aiplatform/explain/lit.py @@ -18,7 +18,7 @@ import os from google.cloud import aiplatform -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Mapping, Optional, Tuple, Union try: from lit_nlp.api import dataset as lit_dataset @@ -154,7 +154,12 @@ def predict_minibatch( prediction_object = self._endpoint.predict(instances) outputs = [] for prediction in prediction_object.predictions: - outputs.append({key: prediction[key] for key in self._output_types}) + if isinstance(prediction, Mapping): + outputs.append({key: prediction[key] for key in self._output_types}) + else: + outputs.append( + {key: prediction[i] for i, key in enumerate(self._output_types)} + ) if self._explanation_enabled: for i, explanation in enumerate(prediction_object.explanations): attributions = explanation.attributions diff --git a/tests/unit/aiplatform/test_explain_lit.py b/tests/unit/aiplatform/test_explain_lit.py index c8092b1742..fe9b269610 100644 --- a/tests/unit/aiplatform/test_explain_lit.py +++ b/tests/unit/aiplatform/test_explain_lit.py @@ -105,7 +105,8 @@ ), ] _TEST_TRAFFIC_SPLIT = {_TEST_ID: 0, _TEST_ID_2: 100, _TEST_ID_3: 0} -_TEST_PREDICTION = [{"label": 1.0}] +_TEST_DICT_PREDICTION = [{"label": 1.0}] +_TEST_LIST_PREDICTION = [[1.0]] _TEST_EXPLANATIONS = [gca_prediction_service.explanation.Explanation(attributions=[])] _TEST_ATTRIBUTIONS = [ gca_prediction_service.explanation.Attribution( @@ -218,26 +219,54 @@ def get_endpoint_with_models_with_explanation_mock(): @pytest.fixture -def predict_client_predict_mock(): +def predict_client_predict_dict_mock(): with mock.patch.object( prediction_service_client.PredictionServiceClient, "predict" ) as predict_mock: predict_mock.return_value = gca_prediction_service.PredictResponse( deployed_model_id=_TEST_ID ) - predict_mock.return_value.predictions.extend(_TEST_PREDICTION) + predict_mock.return_value.predictions.extend(_TEST_DICT_PREDICTION) yield predict_mock @pytest.fixture -def predict_client_explain_mock(): +def predict_client_explain_dict_mock(): with mock.patch.object( prediction_service_client.PredictionServiceClient, "explain" ) as predict_mock: predict_mock.return_value = gca_prediction_service.ExplainResponse( deployed_model_id=_TEST_ID, ) - predict_mock.return_value.predictions.extend(_TEST_PREDICTION) + predict_mock.return_value.predictions.extend(_TEST_DICT_PREDICTION) + predict_mock.return_value.explanations.extend(_TEST_EXPLANATIONS) + predict_mock.return_value.explanations[0].attributions.extend( + _TEST_ATTRIBUTIONS + ) + yield predict_mock + + +@pytest.fixture +def predict_client_predict_list_mock(): + with mock.patch.object( + prediction_service_client.PredictionServiceClient, "predict" + ) as predict_mock: + predict_mock.return_value = gca_prediction_service.PredictResponse( + deployed_model_id=_TEST_ID + ) + predict_mock.return_value.predictions.extend(_TEST_LIST_PREDICTION) + yield predict_mock + + +@pytest.fixture +def predict_client_explain_list_mock(): + with mock.patch.object( + prediction_service_client.PredictionServiceClient, "explain" + ) as predict_mock: + predict_mock.return_value = gca_prediction_service.ExplainResponse( + deployed_model_id=_TEST_ID, + ) + predict_mock.return_value.predictions.extend(_TEST_LIST_PREDICTION) predict_mock.return_value.explanations.extend(_TEST_EXPLANATIONS) predict_mock.return_value.explanations[0].attributions.extend( _TEST_ATTRIBUTIONS @@ -312,10 +341,112 @@ def test_create_lit_model_from_tensorflow_with_xai_returns_model( assert len(item.values()) == 2 @pytest.mark.usefixtures( - "predict_client_predict_mock", "get_endpoint_with_models_mock" + "predict_client_predict_dict_mock", "get_endpoint_with_models_mock" + ) + @pytest.mark.parametrize("model_id", [None, _TEST_ID]) + def test_create_lit_model_from_dict_endpoint_returns_model( + self, feature_types, label_types, model_id + ): + endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME) + lit_model = create_lit_model_from_endpoint( + endpoint, feature_types, label_types, model_id + ) + test_inputs = [ + {"feature_1": 1.0, "feature_2": 2.0}, + ] + outputs = lit_model.predict_minibatch(test_inputs) + + assert lit_model.input_spec() == dict(feature_types) + assert lit_model.output_spec() == dict(label_types) + assert len(outputs) == 1 + for item in outputs: + assert item.keys() == {"label"} + assert len(item.values()) == 1 + + @pytest.mark.usefixtures( + "predict_client_explain_dict_mock", + "get_endpoint_with_models_with_explanation_mock", + ) + @pytest.mark.parametrize("model_id", [None, _TEST_ID]) + def test_create_lit_model_from_dict_endpoint_with_xai_returns_model( + self, feature_types, label_types, model_id + ): + endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME) + lit_model = create_lit_model_from_endpoint( + endpoint, feature_types, label_types, model_id + ) + test_inputs = [ + {"feature_1": 1.0, "feature_2": 2.0}, + ] + outputs = lit_model.predict_minibatch(test_inputs) + + assert lit_model.input_spec() == dict(feature_types) + assert lit_model.output_spec() == dict( + { + **label_types, + "feature_attribution": lit_types.FeatureSalience(signed=True), + } + ) + assert len(outputs) == 1 + for item in outputs: + assert item.keys() == {"label", "feature_attribution"} + assert len(item.values()) == 2 + + @pytest.mark.usefixtures( + "predict_client_predict_dict_mock", "get_endpoint_with_models_mock" + ) + @pytest.mark.parametrize("model_id", [None, _TEST_ID]) + def test_create_lit_model_from_dict_endpoint_name_returns_model( + self, feature_types, label_types, model_id + ): + lit_model = create_lit_model_from_endpoint( + _TEST_ENDPOINT_NAME, feature_types, label_types, model_id + ) + test_inputs = [ + {"feature_1": 1.0, "feature_2": 2.0}, + ] + outputs = lit_model.predict_minibatch(test_inputs) + + assert lit_model.input_spec() == dict(feature_types) + assert lit_model.output_spec() == dict(label_types) + assert len(outputs) == 1 + for item in outputs: + assert item.keys() == {"label"} + assert len(item.values()) == 1 + + @pytest.mark.usefixtures( + "predict_client_explain_dict_mock", + "get_endpoint_with_models_with_explanation_mock", + ) + @pytest.mark.parametrize("model_id", [None, _TEST_ID]) + def test_create_lit_model_from_dict_endpoint_name_with_xai_returns_model( + self, feature_types, label_types, model_id + ): + lit_model = create_lit_model_from_endpoint( + _TEST_ENDPOINT_NAME, feature_types, label_types, model_id + ) + test_inputs = [ + {"feature_1": 1.0, "feature_2": 2.0}, + ] + outputs = lit_model.predict_minibatch(test_inputs) + + assert lit_model.input_spec() == dict(feature_types) + assert lit_model.output_spec() == dict( + { + **label_types, + "feature_attribution": lit_types.FeatureSalience(signed=True), + } + ) + assert len(outputs) == 1 + for item in outputs: + assert item.keys() == {"label", "feature_attribution"} + assert len(item.values()) == 2 + + @pytest.mark.usefixtures( + "predict_client_predict_list_mock", "get_endpoint_with_models_mock" ) @pytest.mark.parametrize("model_id", [None, _TEST_ID]) - def test_create_lit_model_from_endpoint_returns_model( + def test_create_lit_model_from_list_endpoint_returns_model( self, feature_types, label_types, model_id ): endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME) @@ -335,10 +466,11 @@ def test_create_lit_model_from_endpoint_returns_model( assert len(item.values()) == 1 @pytest.mark.usefixtures( - "predict_client_explain_mock", "get_endpoint_with_models_with_explanation_mock" + "predict_client_explain_list_mock", + "get_endpoint_with_models_with_explanation_mock", ) @pytest.mark.parametrize("model_id", [None, _TEST_ID]) - def test_create_lit_model_from_endpoint_with_xai_returns_model( + def test_create_lit_model_from_list_endpoint_with_xai_returns_model( self, feature_types, label_types, model_id ): endpoint = aiplatform.Endpoint(_TEST_ENDPOINT_NAME) @@ -363,10 +495,10 @@ def test_create_lit_model_from_endpoint_with_xai_returns_model( assert len(item.values()) == 2 @pytest.mark.usefixtures( - "predict_client_predict_mock", "get_endpoint_with_models_mock" + "predict_client_predict_list_mock", "get_endpoint_with_models_mock" ) @pytest.mark.parametrize("model_id", [None, _TEST_ID]) - def test_create_lit_model_from_endpoint_name_returns_model( + def test_create_lit_model_from_list_endpoint_name_returns_model( self, feature_types, label_types, model_id ): lit_model = create_lit_model_from_endpoint( @@ -385,10 +517,11 @@ def test_create_lit_model_from_endpoint_name_returns_model( assert len(item.values()) == 1 @pytest.mark.usefixtures( - "predict_client_explain_mock", "get_endpoint_with_models_with_explanation_mock" + "predict_client_explain_list_mock", + "get_endpoint_with_models_with_explanation_mock", ) @pytest.mark.parametrize("model_id", [None, _TEST_ID]) - def test_create_lit_model_from_endpoint_name_with_xai_returns_model( + def test_create_lit_model_from_list_endpoint_name_with_xai_returns_model( self, feature_types, label_types, model_id ): lit_model = create_lit_model_from_endpoint(