Skip to content

Commit

Permalink
[Bug] query_namespaces can handle single result (#421)
Browse files Browse the repository at this point in the history
## Problem

In order to merge results across multiple queries, the SDK must know
which similarity metric an index is using. For dotproduct and cosine
indexes, a larger score is better while for euclidean a smaller score is
better. Unfortunately the data plane API does not currently expose the
metric type and a separate call to the control plane to find out seems
undesirable from a resiliency and performance perspective.

As a workaround, in the initial implementation of `query_namespaces` the
SDK would infer the similarity metric needed to merge results by seeing
whether the scores of query results were ascending or descending. This
worked well, but imposed an implicit limitation that there must be at
least 2 results returned.

We initially believed this would not be a problem but have since learned
that applications using filtering can sometimes filter out all or most
results. So we now prefer an approach in which the user explicitly tells
the SDK which similarity metric the index uses, so that edge cases with
1 or 0 results can still be handled correctly.

## Solution

- Add a required kwarg to `query_namespaces` to specify the index
similarity metric.
- Modify `QueryResultsAggregator` to use this similarity metric, and
strip out code that was involved in inferring whether results were
ascending or descending.
- Adjust integration tests to pass the new metric kwarg. Aside from adding
the new kwarg, the query_namespaces integration tests did not need to change,
which indicates the underlying behavior still works as before.

## Type of Change

- [x] Bug fix (non-breaking change which fixes an issue)
  • Loading branch information
jhamon authored Dec 6, 2024
1 parent ec9afe1 commit 5453aab
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 182 deletions.
7 changes: 5 additions & 2 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import json
from typing import Union, List, Optional, Dict, Any
from typing import Union, List, Optional, Dict, Any, Literal

from pinecone.config import ConfigBuilder

Expand Down Expand Up @@ -511,6 +511,7 @@ def query_namespaces(
self,
vector: List[float],
namespaces: List[str],
metric: Literal["cosine", "euclidean", "dotproduct"],
top_k: Optional[int] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
Expand Down Expand Up @@ -540,6 +541,7 @@ def query_namespaces(
combined_results = index.query_namespaces(
vector=query_vec,
namespaces=['ns1', 'ns2', 'ns3', 'ns4'],
metric="cosine",
top_k=10,
filter={'genre': {"$eq": "drama"}},
include_values=True,
Expand All @@ -554,6 +556,7 @@ def query_namespaces(
vector (List[float]): The query vector, must be the same length as the dimension of the index being queried.
namespaces (List[str]): The list of namespaces to query.
top_k (Optional[int], optional): The number of results you would like to request from each namespace. Defaults to 10.
metric (str): Must be one of 'cosine', 'euclidean', 'dotproduct'. This is needed in order to merge results across namespaces, since the interpretation of score depends on the index metric type.
filter (Optional[Dict[str, Union[str, float, int, bool, List, dict]]], optional): Pass an optional filter to filter results based on metadata. Defaults to None.
include_values (Optional[bool], optional): Boolean field indicating whether vector values should be included with results. Defaults to None.
include_metadata (Optional[bool], optional): Boolean field indicating whether vector metadata should be included with results. Defaults to None.
Expand All @@ -568,7 +571,7 @@ def query_namespaces(
raise ValueError("Query vector must not be empty")

overall_topk = top_k if top_k is not None else 10
aggregator = QueryResultsAggregator(top_k=overall_topk)
aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric)

target_namespaces = set(namespaces) # dedup namespaces
async_futures = [
Expand Down
62 changes: 23 additions & 39 deletions pinecone/data/query_results_aggregator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Optional, Any, Dict
from typing import List, Tuple, Optional, Any, Dict, Literal
import json
import heapq
from pinecone.core.openapi.data.models import Usage
Expand Down Expand Up @@ -88,46 +88,38 @@ def __repr__(self):
)


class QueryResultsAggregregatorNotEnoughResultsError(Exception):
def __init__(self):
super().__init__(
"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
)


class QueryResultsAggregatorInvalidTopKError(Exception):
def __init__(self, top_k: int):
super().__init__(
f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
)
super().__init__(f"Invalid top_k value {top_k}. top_k must be at least 1.")


class QueryResultsAggregator:
def __init__(self, top_k: int):
if top_k < 2:
def __init__(self, top_k: int, metric: Literal["cosine", "euclidean", "dotproduct"]):
if top_k < 1:
raise QueryResultsAggregatorInvalidTopKError(top_k)

if metric in ["dotproduct", "cosine"]:
self.is_bigger_better = True
elif metric in ["euclidean"]:
self.is_bigger_better = False
else:
raise ValueError(
f"Cannot merge results for unknown similarity metric {metric}. Supported metrics are 'dotproduct', 'cosine', and 'euclidean'."
)

self.top_k = top_k
self.usage_read_units = 0
self.heap: List[Tuple[float, int, object, str]] = []
self.insertion_counter = 0
self.is_dotproduct = None
self.read = False
self.final_results: Optional[QueryNamespacesResults] = None

def _is_dotproduct_index(self, matches):
# The interpretation of the score depends on the similar metric used.
# Unlike other index types, in indexes configured for dotproduct,
# a higher score is better. We have to infer this is the case by inspecting
# the order of the scores in the results.
for i in range(1, len(matches)):
if matches[i].get("score") > matches[i - 1].get("score"): # Found an increase
return False
return True

def _dotproduct_heap_item(self, match, ns):
def _bigger_better_heap_item(self, match, ns):
# This 4-tuple is used to ensure that the heap is sorted by score followed by
# insertion order. The insertion order is used to break any ties in the score.
return (match.get("score"), -self.insertion_counter, match, ns)

def _non_dotproduct_heap_item(self, match, ns):
def _smaller_better_heap_item(self, match, ns):
return (-match.get("score"), -self.insertion_counter, match, ns)

def _process_matches(self, matches, ns, heap_item_fn):
Expand All @@ -137,10 +129,10 @@ def _process_matches(self, matches, ns, heap_item_fn):
heapq.heappush(self.heap, heap_item_fn(match, ns))
else:
# Assume we have dotproduct scores sorted in descending order
if self.is_dotproduct and match["score"] < self.heap[0][0]:
if self.is_bigger_better and match["score"] < self.heap[0][0]:
# No further matches can improve the top-K heap
break
elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
elif not self.is_bigger_better and match["score"] > -self.heap[0][0]:
# No further matches can improve the top-K heap
break
heapq.heappushpop(self.heap, heap_item_fn(match, ns))
Expand All @@ -162,18 +154,10 @@ def add_results(self, results: Dict[str, Any]):
if len(matches) == 0:
return

if self.is_dotproduct is None:
if len(matches) == 1:
# This condition should match the second time we add results containing
# only one match. We need at least two matches in a single response in order
# to infer the similarity metric
raise QueryResultsAggregregatorNotEnoughResultsError()
self.is_dotproduct = self._is_dotproduct_index(matches)

if self.is_dotproduct:
self._process_matches(matches, ns, self._dotproduct_heap_item)
if self.is_bigger_better:
self._process_matches(matches, ns, self._bigger_better_heap_item)
else:
self._process_matches(matches, ns, self._non_dotproduct_heap_item)
self._process_matches(matches, ns, self._smaller_better_heap_item)

def get_results(self) -> QueryNamespacesResults:
if self.read:
Expand Down
5 changes: 3 additions & 2 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast, Literal

from google.protobuf import json_format

Expand Down Expand Up @@ -409,6 +409,7 @@ def query_namespaces(
self,
vector: List[float],
namespaces: List[str],
metric: Literal["cosine", "euclidean", "dotproduct"],
top_k: Optional[int] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
Expand All @@ -422,7 +423,7 @@ def query_namespaces(
raise ValueError("Query vector must not be empty")

overall_topk = top_k if top_k is not None else 10
aggregator = QueryResultsAggregator(top_k=overall_topk)
aggregator = QueryResultsAggregator(top_k=overall_topk, metric=metric)

target_namespaces = set(namespaces) # dedup namespaces
futures = [
Expand Down
69 changes: 35 additions & 34 deletions tests/integration/data/test_query_namespaces.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import pytest
from ..helpers import random_string, poll_stats_for_namespace
from pinecone.data.query_results_aggregator import (
QueryResultsAggregatorInvalidTopKError,
QueryResultsAggregregatorNotEnoughResultsError,
)

from pinecone import Vector


class TestQueryNamespacesRest:
def test_query_namespaces(self, idx):
def test_query_namespaces(self, idx, metric):
ns_prefix = random_string(5)
ns1 = f"{ns_prefix}-ns1"
ns2 = f"{ns_prefix}-ns2"
Expand Down Expand Up @@ -50,6 +46,7 @@ def test_query_namespaces(self, idx):
results = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "drama"}},
Expand Down Expand Up @@ -84,6 +81,7 @@ def test_query_namespaces(self, idx):
results2 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3, f"{ns_prefix}-nonexistent"],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "action"}},
Expand All @@ -98,6 +96,7 @@ def test_query_namespaces(self, idx):
results3 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={},
Expand All @@ -110,6 +109,7 @@ def test_query_namespaces(self, idx):
results4 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "comedy"}},
Expand All @@ -122,6 +122,7 @@ def test_query_namespaces(self, idx):
results5 = idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2, ns3],
metric=metric,
include_values=True,
include_metadata=True,
filter={},
Expand All @@ -137,6 +138,7 @@ def test_query_namespaces(self, idx):
f"{ns_prefix}-nonexistent2",
f"{ns_prefix}-nonexistent3",
],
metric=metric,
include_values=True,
include_metadata=True,
filter={"genre": {"$eq": "comedy"}},
Expand All @@ -145,22 +147,7 @@ def test_query_namespaces(self, idx):
assert len(results6.matches) == 0
assert results6.usage.read_units > 0

def test_invalid_top_k(self, idx):
with pytest.raises(QueryResultsAggregatorInvalidTopKError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=["ns1", "ns2", "ns3"],
include_values=True,
include_metadata=True,
filter={},
top_k=1,
)
assert (
str(e.value)
== "Invalid top_k value 1. To aggregate results from multiple queries the top_k must be at least 2."
)

def test_unmergeable_results(self, idx):
def test_single_result_per_namespace(self, idx):
ns_prefix = random_string(5)
ns1 = f"{ns_prefix}-ns1"
ns2 = f"{ns_prefix}-ns2"
Expand All @@ -183,26 +170,27 @@ def test_unmergeable_results(self, idx):
poll_stats_for_namespace(idx, namespace=ns1, expected_count=2)
poll_stats_for_namespace(idx, namespace=ns2, expected_count=2)

with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[ns1, ns2],
include_values=True,
include_metadata=True,
filter={"key": {"$eq": 1}},
top_k=2,
)

assert (
str(e.value)
== "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
results = idx.query_namespaces(
vector=[0.1, 0.21],
namespaces=[ns1, ns2],
metric="cosine",
include_values=True,
include_metadata=True,
filter={"key": {"$eq": 1}},
top_k=2,
)
assert len(results.matches) == 2
assert results.matches[0].id == "id1"
assert results.matches[0].namespace == ns1
assert results.matches[1].id == "id5"
assert results.matches[1].namespace == ns2

def test_missing_namespaces(self, idx):
with pytest.raises(ValueError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=[],
metric="cosine",
include_values=True,
include_metadata=True,
filter={},
Expand All @@ -214,9 +202,22 @@ def test_missing_namespaces(self, idx):
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=None,
metric="cosine",
include_values=True,
include_metadata=True,
filter={},
top_k=2,
)
assert str(e.value) == "At least one namespace must be specified"

def test_missing_metric(self, idx):
with pytest.raises(TypeError) as e:
idx.query_namespaces(
vector=[0.1, 0.2],
namespaces=["ns1"],
include_values=True,
include_metadata=True,
filter={},
top_k=2,
)
assert "query_namespaces() missing 1 required positional argument: 'metric'" in str(e.value)
Loading

0 comments on commit 5453aab

Please sign in to comment.