Skip to content

Commit

Permalink
feat: Add hybrid search for public find_neighbors() call.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 640750317
  • Loading branch information
lingyinw authored and copybara-github committed Jun 6, 2024
1 parent c118557 commit 9d35617
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Optional, Sequence, Tuple, Union

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
Expand Down Expand Up @@ -148,6 +148,37 @@ def __post_init__(self):
)


@dataclass
class HybridQuery:
    """Hybrid query. Can be used for dense-only, sparse-only, or hybrid queries.

    Every field defaults to None, so a caller sets only the parts that apply
    to the kind of query being issued.

    dense_embedding (List[float]):
        Optional. The dense part of the hybrid query.
    sparse_embedding_values (List[float]):
        Optional. The sparse values of the sparse part of the query.
    sparse_embedding_dimensions (List[int]):
        Optional. The corresponding dimensions of the sparse values.
        For example, values [1,2,3] with dimensions [4,5,6] means value 1 is
        of the 4th dimension, value 2 is of the 5th dimension, and value 3 is
        of the 6th dimension.
    rrf_ranking_alpha (float):
        Optional. This should not be specified for dense-only or sparse-only
        queries. A value between 0 and 1 for the RRF ranking algorithm,
        representing the ratio of sparse vs. dense embeddings returned in
        the query result. If alpha is 0, only sparse embeddings are returned
        and no dense embedding is returned. When alpha is 1, only dense
        embeddings are returned and no sparse embedding is returned.
    """

    # Each default is None, so the annotations are spelled Optional[...]
    # explicitly instead of relying on an implicit Optional.
    dense_embedding: Optional[List[float]] = None
    sparse_embedding_values: Optional[List[float]] = None
    sparse_embedding_dimensions: Optional[List[int]] = None
    rrf_ranking_alpha: Optional[float] = None


@dataclass
class MatchNeighbor:
"""The id and distance of a nearest neighbor match for a given query embedding.
Expand All @@ -157,7 +188,7 @@ class MatchNeighbor:
Required. The id of the neighbor.
distance (float):
Required. The distance to the query embedding.
feature_vector (List(float)):
feature_vector (List[float]):
Optional. The feature vector of the matching datapoint.
crowding_tag (Optional[str]):
Optional. Crowding tag of the datapoint, the
Expand All @@ -167,6 +198,14 @@ class MatchNeighbor:
Optional. The restricts of the matching datapoint.
numeric_restricts:
Optional. The numeric restricts of the matching datapoint.
sparse_embedding_values (List[float]):
Optional. The sparse values of the sparse part of the matching
datapoint.
sparse_embedding_dimensions (List[int]):
Optional. The corresponding dimensions of the sparse values.
For example, values [1,2,3] with dimensions [4,5,6] means value 1 is
of the 4th dimension, value 2 is of the 5th dimension, and value 3 is
of the 6th dimension.
"""

Expand All @@ -176,6 +215,8 @@ class MatchNeighbor:
crowding_tag: Optional[str] = None
restricts: Optional[List[Namespace]] = None
numeric_restricts: Optional[List[NumericNamespace]] = None
sparse_embedding_values: Optional[List[float]] = None
sparse_embedding_dimensions: Optional[List[int]] = None

def from_index_datapoint(
self, index_datapoint: gca_index_v1beta1.IndexDatapoint
Expand Down Expand Up @@ -207,22 +248,31 @@ def from_index_datapoint(
]
if index_datapoint.numeric_restricts is not None:
self.numeric_restricts = []
for restrict in index_datapoint.numeric_restricts:
numeric_namespace = None
restrict_value_type = restrict._pb.WhichOneof("Value")
if restrict_value_type == "value_int":
numeric_namespace = NumericNamespace(
name=restrict.namespace, value_int=restrict.value_int
)
elif restrict_value_type == "value_float":
numeric_namespace = NumericNamespace(
name=restrict.namespace, value_float=restrict.value_float
)
elif restrict_value_type == "value_double":
numeric_namespace = NumericNamespace(
name=restrict.namespace, value_double=restrict.value_double
)
self.numeric_restricts.append(numeric_namespace)
for restrict in index_datapoint.numeric_restricts:
numeric_namespace = None
restrict_value_type = restrict._pb.WhichOneof("Value")
if restrict_value_type == "value_int":
numeric_namespace = NumericNamespace(
name=restrict.namespace, value_int=restrict.value_int
)
elif restrict_value_type == "value_float":
numeric_namespace = NumericNamespace(
name=restrict.namespace, value_float=restrict.value_float
)
elif restrict_value_type == "value_double":
numeric_namespace = NumericNamespace(
name=restrict.namespace, value_double=restrict.value_double
)
self.numeric_restricts.append(numeric_namespace)
# sparse embeddings
if (
index_datapoint.sparse_embedding is not None
and index_datapoint.sparse_embedding.values is not None
):
self.sparse_embedding_values = index_datapoint.sparse_embedding.values
self.sparse_embedding_dimensions = (
index_datapoint.sparse_embedding.dimensions
)
return self

def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighbor":
Expand Down Expand Up @@ -250,22 +300,22 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb
]
if embedding.numeric_restricts:
self.numeric_restricts = []
for restrict in embedding.numeric_restricts:
numeric_namespace = None
restrict_value_type = restrict.WhichOneof("Value")
if restrict_value_type == "value_int":
numeric_namespace = NumericNamespace(
name=restrict.name, value_int=restrict.value_int
)
elif restrict_value_type == "value_float":
numeric_namespace = NumericNamespace(
name=restrict.name, value_float=restrict.value_float
)
elif restrict_value_type == "value_double":
numeric_namespace = NumericNamespace(
name=restrict.name, value_double=restrict.value_double
)
self.numeric_restricts.append(numeric_namespace)
for restrict in embedding.numeric_restricts:
numeric_namespace = None
restrict_value_type = restrict.WhichOneof("Value")
if restrict_value_type == "value_int":
numeric_namespace = NumericNamespace(
name=restrict.name, value_int=restrict.value_int
)
elif restrict_value_type == "value_float":
numeric_namespace = NumericNamespace(
name=restrict.name, value_float=restrict.value_float
)
elif restrict_value_type == "value_double":
numeric_namespace = NumericNamespace(
name=restrict.name, value_double=restrict.value_double
)
self.numeric_restricts.append(numeric_namespace)
return self


Expand Down Expand Up @@ -1322,7 +1372,7 @@ def find_neighbors(
self,
*,
deployed_index_id: str,
queries: Optional[List[List[float]]] = None,
queries: Optional[Union[List[List[float]], List[HybridQuery]]] = None,
num_neighbors: int = 10,
filter: Optional[List[Namespace]] = None,
per_crowding_attribute_neighbor_count: Optional[int] = None,
Expand All @@ -1346,8 +1396,15 @@ def find_neighbors(
Args:
deployed_index_id (str):
Required. The ID of the DeployedIndex to match the queries against.
queries (List[List[float]]):
Required. A list of queries. Each query is a list of floats, representing a single embedding.
queries (Union[List[List[float]], List[HybridQuery]]):
Optional. A list of queries.
For regular dense-only queries, each query is a list of floats,
representing a single embedding.
For hybrid queries, each query is a hybrid query of type
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery.
num_neighbors (int):
Required. The number of nearest neighbors to be retrieved from database for
each query.
Expand Down Expand Up @@ -1381,7 +1438,7 @@ def find_neighbors(
Note that returning full datapoint will significantly increase the
latency and cost of the query.
numeric_filter (list[NumericNamespace]):
numeric_filter (List[NumericNamespace]):
Optional. A list of NumericNamespaces for filtering the matching
results. For example:
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
Expand Down Expand Up @@ -1437,30 +1494,54 @@ def find_neighbors(
numeric_restrict.value_double = numeric_namespace.value_double
numeric_restricts.append(numeric_restrict)
# Queries
query_by_id = False if queries else True
queries = queries if queries else embedding_ids
if queries:
for query in queries:
find_neighbors_query = gca_match_service_v1beta1.FindNeighborsRequest.Query(
neighbor_count=num_neighbors,
per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count,
approximate_neighbor_count=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
)
datapoint = gca_index_v1beta1.IndexDatapoint(
datapoint_id=query if query_by_id else None,
feature_vector=None if query_by_id else query,
)
datapoint.restricts.extend(restricts)
datapoint.numeric_restricts.extend(numeric_restricts)
find_neighbors_query.datapoint = datapoint
find_neighbors_request.queries.append(find_neighbors_query)
query_by_id = False
query_is_hybrid = False
if embedding_ids:
query_by_id = True
query_iterators: list[str] = embedding_ids
elif queries:
query_is_hybrid = isinstance(queries[0], HybridQuery)
query_iterators = queries
else:
raise ValueError(
"To find neighbors using matching engine,"
"please specify `queries` or `embedding_ids`"
"please specify `queries` or `embedding_ids` or `hybrid_queries`"
)

for query in query_iterators:
find_neighbors_query = gca_match_service_v1beta1.FindNeighborsRequest.Query(
neighbor_count=num_neighbors,
per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count,
approximate_neighbor_count=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
)
if query_by_id:
datapoint = gca_index_v1beta1.IndexDatapoint(
datapoint_id=query,
)
elif query_is_hybrid:
datapoint = gca_index_v1beta1.IndexDatapoint(
feature_vector=query.dense_embedding,
sparse_embedding=gca_index_v1beta1.IndexDatapoint.SparseEmbedding(
values=query.sparse_embedding_values,
dimensions=query.sparse_embedding_dimensions,
),
)
if query.rrf_ranking_alpha:
find_neighbors_query.rrf = (
gca_match_service_v1beta1.FindNeighborsRequest.Query.RRF(
alpha=query.rrf_ranking_alpha,
)
)
else:
datapoint = gca_index_v1beta1.IndexDatapoint(
feature_vector=query,
)
datapoint.restricts.extend(restricts)
datapoint.numeric_restricts.extend(numeric_restricts)
find_neighbors_query.datapoint = datapoint
find_neighbors_request.queries.append(find_neighbors_query)

response = self._public_match_client.find_neighbors(find_neighbors_request)

# Wrap the results in MatchNeighbor objects and return
Expand Down Expand Up @@ -1543,7 +1624,6 @@ def read_index_datapoints(
read_index_datapoints_request
)

# Wrap the results and return
return response.datapoints

def _batch_get_embeddings(
Expand Down
Loading

0 comments on commit 9d35617

Please sign in to comment.