diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index ac0bb0cc00..c726b1c96c 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -216,6 +216,7 @@ def __init__( ) self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name) + self._public_match_client = None if self.public_endpoint_domain_name: self._public_match_client = self._instantiate_public_match_client() @@ -518,6 +519,36 @@ def _instantiate_public_match_client( api_path_override=self.public_endpoint_domain_name, ) + def _instantiate_private_match_service_stub( + self, + deployed_index_id: str, + ) -> match_service_pb2_grpc.MatchServiceStub: + """Helper method to instantiate private match service stub. + Args: + deployed_index_id (str): + Required. The user specified ID of the + DeployedIndex. + Returns: + stub (match_service_pb2_grpc.MatchServiceStub): + Initialized match service stub. + """ + # Find the deployed index by id + deployed_indexes = [ + deployed_index + for deployed_index in self.deployed_indexes + if deployed_index.id == deployed_index_id + ] + + if not deployed_indexes: + raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found") + + # Retrieve server ip from deployed index + server_ip = deployed_indexes[0].private_endpoints.match_grpc_address + + # Set up channel and stub + channel = grpc.insecure_channel("{}:10000".format(server_ip)) + return match_service_pb2_grpc.MatchServiceStub(channel) + @property def public_endpoint_domain_name(self) -> Optional[str]: """Public endpoint DNS name.""" @@ -1233,7 +1264,8 @@ def read_index_datapoints( deployed_index_id: str, ids: List[str] = [], ) -> List[gca_index_v1beta1.IndexDatapoint]: - """Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public endpoint. + """Reads the datapoints/vectors of the given IDs on the specified + deployed index which is deployed to public or private endpoint. ``` Example Usage: @@ -1252,9 +1284,25 @@ def read_index_datapoints( List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs. """ if not self._public_match_client: - raise ValueError( - "Please make sure index has been deployed to public endpoint, and follow the example usage to call this method." + # Call private match service stub with BatchGetEmbeddings request + response = self._batch_get_embeddings( + deployed_index_id=deployed_index_id, ids=ids ) + return [ + gca_index_v1beta1.IndexDatapoint( + datapoint_id=embedding.id, + feature_vector=embedding.float_val, + restricts=gca_index_v1beta1.IndexDatapoint.Restriction( + namespace=embedding.restricts.name, + allow_list=embedding.restricts.allow_tokens, + ), + deny_list=embedding.restricts.deny_tokens, + crowding_attributes=gca_index_v1beta1.CrowdingEmbedding( + str(embedding.crowding_tag) + ), + ) + for embedding in response.embeddings + ] # Create the ReadIndexDatapoints request read_index_datapoints_request = ( @@ -1273,6 +1321,38 @@ def read_index_datapoints( # Wrap the results and return return response.datapoints + def _batch_get_embeddings( + self, + *, + deployed_index_id: str, + ids: List[str] = [], + ) -> List[List[match_service_pb2.Embedding]]: + """ + Reads the datapoints/vectors of the given IDs on the specified index + which is deployed to private endpoint. + + Args: + deployed_index_id (str): + Required. The ID of the DeployedIndex to match the queries against. + ids (List[str]): + Required. IDs of the datapoints to be searched for. + Returns: + List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs. + """ + stub = self._instantiate_private_match_service_stub( + deployed_index_id=deployed_index_id + ) + + # Create the batch get embeddings request + batch_request = match_service_pb2.BatchGetEmbeddingsRequest() + batch_request.deployed_index_id = deployed_index_id + + for id in ids: + batch_request.id.append(id) + response = stub.BatchGetEmbeddings(batch_request) + + return response.embeddings + def match( self, deployed_index_id: str, @@ -1310,23 +1390,9 @@ def match( Returns: List[List[MatchNeighbor]] - A list of nearest neighbors for each query. """ - - # Find the deployed index by id - deployed_indexes = [ - deployed_index - for deployed_index in self.deployed_indexes - if deployed_index.id == deployed_index_id - ] - - if not deployed_indexes: - raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found") - - # Retrieve server ip from deployed index - server_ip = deployed_indexes[0].private_endpoints.match_grpc_address - - # Set up channel and stub - channel = grpc.insecure_channel("{}:10000".format(server_ip)) - stub = match_service_pb2_grpc.MatchServiceStub(channel) + stub = self._instantiate_private_match_service_stub( + deployed_index_id=deployed_index_id + ) # Create the batch match request batch_request = match_service_pb2.BatchMatchRequest() diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 581a584c4e..5aac0a0b3e 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -493,6 +493,31 @@ def index_endpoint_match_queries_mock(): yield index_endpoint_match_queries_mock +@pytest.fixture +def index_endpoint_batch_get_embeddings_mock(): + with patch.object( + grpc._channel._UnaryUnaryMultiCallable, + "__call__", + ) as index_endpoint_batch_get_embeddings_mock: + index_endpoint_batch_get_embeddings_mock.return_value = ( + match_service_pb2.BatchGetEmbeddingsResponse( + embeddings=[ + match_service_pb2.Embedding( + id="1", + float_val=[0.1, 0.2, 0.3], + crowding_attribute=1, + ), + match_service_pb2.Embedding( + id="2", + float_val=[0.5, 0.2, 0.3], + crowding_attribute=1, + ), + ] + ) + ) + yield index_endpoint_batch_get_embeddings_mock + + @pytest.fixture def index_public_endpoint_match_queries_mock(): with patch.object( @@ -1204,3 +1229,23 @@ def test_index_public_endpoint_read_index_datapoints( index_public_endpoint_read_index_datapoints_mock.assert_called_with( read_index_datapoints_request ) + + @pytest.mark.usefixtures("get_index_endpoint_mock") + def test_index_endpoint_batch_get_embeddings( + self, index_endpoint_batch_get_embeddings_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_index_endpoint._batch_get_embeddings( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, ids=["1", "2"] + ) + + batch_request = match_service_pb2.BatchGetEmbeddingsRequest( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"] + ) + + index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)