Skip to content

Commit

Permalink
feat: support read_index_datapoints for private network.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 589281257
  • Loading branch information
lingyinw authored and copybara-github committed Dec 9, 2023
1 parent a8b24ad commit c9f7119
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def __init__(
)
self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name)

self._public_match_client = None
if self.public_endpoint_domain_name:
self._public_match_client = self._instantiate_public_match_client()

Expand Down Expand Up @@ -518,6 +519,36 @@ def _instantiate_public_match_client(
api_path_override=self.public_endpoint_domain_name,
)

def _instantiate_private_match_service_stub(
self,
deployed_index_id: str,
) -> match_service_pb2_grpc.MatchServiceStub:
"""Helper method to instantiate private match service stub.
Args:
deployed_index_id (str):
Required. The user specified ID of the
DeployedIndex.
Returns:
stub (match_service_pb2_grpc.MatchServiceStub):
Initialized match service stub.
"""
# Find the deployed index by id
deployed_indexes = [
deployed_index
for deployed_index in self.deployed_indexes
if deployed_index.id == deployed_index_id
]

if not deployed_indexes:
raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")

# Retrieve server ip from deployed index
server_ip = deployed_indexes[0].private_endpoints.match_grpc_address

# Set up channel and stub
channel = grpc.insecure_channel("{}:10000".format(server_ip))
return match_service_pb2_grpc.MatchServiceStub(channel)

@property
def public_endpoint_domain_name(self) -> Optional[str]:
"""Public endpoint DNS name."""
Expand Down Expand Up @@ -1233,7 +1264,8 @@ def read_index_datapoints(
deployed_index_id: str,
ids: List[str] = [],
) -> List[gca_index_v1beta1.IndexDatapoint]:
"""Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public endpoint.
"""Reads the datapoints/vectors of the given IDs on the specified
deployed index which is deployed to public or private endpoint.
```
Example Usage:
Expand All @@ -1252,9 +1284,25 @@ def read_index_datapoints(
List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs.
"""
if not self._public_match_client:
raise ValueError(
"Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
# Call private match service stub with BatchGetEmbeddings request
response = self._batch_get_embeddings(
deployed_index_id=deployed_index_id, ids=ids
)
return [
gca_index_v1beta1.IndexDatapoint(
datapoint_id=embedding.id,
feature_vector=embedding.float_val,
restricts=gca_index_v1beta1.IndexDatapoint.Restriction(
namespace=embedding.restricts.name,
allow_list=embedding.restricts.allow_tokens,
),
deny_list=embedding.restricts.deny_tokens,
crowding_attributes=gca_index_v1beta1.CrowdingEmbedding(
str(embedding.crowding_tag)
),
)
for embedding in response.embeddings
]

# Create the ReadIndexDatapoints request
read_index_datapoints_request = (
Expand All @@ -1273,6 +1321,38 @@ def read_index_datapoints(
# Wrap the results and return
return response.datapoints

def _batch_get_embeddings(
self,
*,
deployed_index_id: str,
ids: List[str] = [],
) -> List[List[match_service_pb2.Embedding]]:
"""
Reads the datapoints/vectors of the given IDs on the specified index
which is deployed to private endpoint.
Args:
deployed_index_id (str):
Required. The ID of the DeployedIndex to match the queries against.
ids (List[str]):
Required. IDs of the datapoints to be searched for.
Returns:
List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
"""
stub = self._instantiate_private_match_service_stub(
deployed_index_id=deployed_index_id
)

# Create the batch get embeddings request
batch_request = match_service_pb2.BatchGetEmbeddingsRequest()
batch_request.deployed_index_id = deployed_index_id

for id in ids:
batch_request.id.append(id)
response = stub.BatchGetEmbeddings(batch_request)

return response.embeddings

def match(
self,
deployed_index_id: str,
Expand Down Expand Up @@ -1310,23 +1390,9 @@ def match(
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
"""

# Find the deployed index by id
deployed_indexes = [
deployed_index
for deployed_index in self.deployed_indexes
if deployed_index.id == deployed_index_id
]

if not deployed_indexes:
raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")

# Retrieve server ip from deployed index
server_ip = deployed_indexes[0].private_endpoints.match_grpc_address

# Set up channel and stub
channel = grpc.insecure_channel("{}:10000".format(server_ip))
stub = match_service_pb2_grpc.MatchServiceStub(channel)
stub = self._instantiate_private_match_service_stub(
deployed_index_id=deployed_index_id
)

# Create the batch match request
batch_request = match_service_pb2.BatchMatchRequest()
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,31 @@ def index_endpoint_match_queries_mock():
yield index_endpoint_match_queries_mock


@pytest.fixture
def index_endpoint_batch_get_embeddings_mock():
with patch.object(
grpc._channel._UnaryUnaryMultiCallable,
"__call__",
) as index_endpoint_batch_get_embeddings_mock:
index_endpoint_batch_get_embeddings_mock.return_value = (
match_service_pb2.BatchGetEmbeddingsResponse(
embeddings=[
match_service_pb2.Embedding(
id="1",
float_val=[0.1, 0.2, 0.3],
crowding_attribute=1,
),
match_service_pb2.Embedding(
id="2",
float_val=[0.5, 0.2, 0.3],
crowding_attribute=1,
),
]
)
)
yield index_endpoint_batch_get_embeddings_mock


@pytest.fixture
def index_public_endpoint_match_queries_mock():
with patch.object(
Expand Down Expand Up @@ -1204,3 +1229,23 @@ def test_index_public_endpoint_read_index_datapoints(
index_public_endpoint_read_index_datapoints_mock.assert_called_with(
read_index_datapoints_request
)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_endpoint_batch_get_embeddings(
self, index_endpoint_batch_get_embeddings_mock
):
aiplatform.init(project=_TEST_PROJECT)

my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

my_index_endpoint._batch_get_embeddings(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, ids=["1", "2"]
)

batch_request = match_service_pb2.BatchGetEmbeddingsRequest(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
)

index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)

0 comments on commit c9f7119

Please sign in to comment.