Skip to content

Commit 3ebc929

Browse files
Adds multi vector query class (#402)
1 parent 66ef546 commit 3ebc929

File tree

8 files changed

+786
-12
lines changed

8 files changed

+786
-12
lines changed

docs/api/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Reference documentation for the RedisVL API.
1515
1616
schema
1717
searchindex
18+
vector
1819
query
1920
filter
2021
vectorizer

docs/api/query.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,17 @@ CountQuery
8888
:inherited-members:
8989
:show-inheritance:
9090
:exclude-members: add_filter,get_args,highlight,return_field,summarize
91+
92+
93+
94+
MultiVectorQuery
95+
================
96+
97+
.. currentmodule:: redisvl.query
98+
99+
100+
.. autoclass:: MultiVectorQuery
101+
:members:
102+
:inherited-members:
103+
:show-inheritance:
104+
:exclude-members: add_filter,get_args,highlight,return_field,summarize

docs/api/vector.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
******
3+
Vector
4+
******
5+
6+
The Vector class in RedisVL is a container that encapsulates a numerical vector, its data type, the name of the corresponding index field, and an optional importance weight. It is used when constructing multi-vector queries using the MultiVectorQuery class.
7+
8+
9+
Vector
10+
===========
11+
12+
.. currentmodule:: redisvl.query
13+
14+
15+
.. autoclass:: Vector
16+
:members:
17+
:exclude-members:

redisvl/query/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from redisvl.query.aggregate import AggregationQuery, HybridQuery
1+
from redisvl.query.aggregate import (
2+
AggregationQuery,
3+
HybridQuery,
4+
MultiVectorQuery,
5+
Vector,
6+
)
27
from redisvl.query.query import (
38
BaseQuery,
49
BaseVectorQuery,
@@ -21,4 +26,6 @@
2126
"TextQuery",
2227
"AggregationQuery",
2328
"HybridQuery",
29+
"MultiVectorQuery",
30+
"Vector",
2431
]

redisvl/query/aggregate.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,41 @@
11
from typing import Any, Dict, List, Optional, Set, Tuple, Union
22

3+
from pydantic import BaseModel, field_validator
34
from redis.commands.search.aggregation import AggregateRequest, Desc
45

56
from redisvl.query.filter import FilterExpression
67
from redisvl.redis.utils import array_to_buffer
8+
from redisvl.schema.fields import VectorDataType
79
from redisvl.utils.token_escaper import TokenEscaper
810
from redisvl.utils.utils import lazy_import
911

1012
nltk = lazy_import("nltk")
1113
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
1214

1315

16+
class Vector(BaseModel):
    """
    Container for a single vector input of a multi vector query.

    Bundles the query vector with the name of the index field it is searched
    against, the datatype name used when a list vector is serialized, and the
    weight given to this vector's similarity score.
    """

    # Query vector: a list of floats, or an already-serialized byte string.
    vector: Union[List[float], bytes]
    # Name of the vector field in the index that this vector is compared against.
    field_name: str
    # Datatype name; validated case-insensitively against VectorDataType.
    dtype: str = "float32"
    # Multiplier applied to this vector's similarity score in the combined score.
    weight: float = 1.0

    @field_validator("dtype")
    @classmethod
    def validate_dtype(cls, dtype: str) -> str:
        """Check that ``dtype`` names a supported vector datatype.

        Args:
            dtype (str): The datatype name to validate.

        Returns:
            str: The datatype exactly as supplied (its case is not normalized).

        Raises:
            ValueError: If ``dtype.upper()`` is not a valid ``VectorDataType``.
        """
        try:
            VectorDataType(dtype.upper())
        except ValueError as err:
            # Chain the enum lookup failure so the root cause stays visible.
            raise ValueError(
                f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
            ) from err

        return dtype
37+
38+
1439
class AggregationQuery(AggregateRequest):
1540
"""
1641
Base class for aggregation queries used to create aggregation queries for Redis.
@@ -227,3 +252,149 @@ def _build_query_string(self) -> str:
227252
def __str__(self) -> str:
228253
"""Return the string representation of the query."""
229254
return " ".join([str(x) for x in self.build_args()])
255+
256+
257+
class MultiVectorQuery(AggregationQuery):
    """
    MultiVectorQuery allows for search over multiple vector fields in a document simultaneously.
    The final score will be a weighted combination of the individual vector similarity scores
    following the formula:

    score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... )

    Vectors may be of different size and datatype, but must be indexed using the 'cosine' distance_metric.

    .. code-block:: python

        from redisvl.query import MultiVectorQuery, Vector
        from redisvl.index import SearchIndex

        index = SearchIndex.from_yaml("path/to/index.yaml")

        vector_1 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float32",
            weight=0.7,
        )
        vector_2 = Vector(
            vector=[0.5, 0.5],
            field_name="image_vector",
            dtype="bfloat16",
            weight=0.2,
        )
        vector_3 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="audio_vector",
            dtype="float64",
            weight=0.5,
        )

        query = MultiVectorQuery(
            vectors=[vector_1, vector_2, vector_3],
            filter_expression=None,
            num_results=10,
            return_fields=["field1", "field2"],
            dialect=2,
        )

        results = index.query(query)
    """

    _vectors: List[Vector]

    def __init__(
        self,
        vectors: Union[Vector, List[Vector]],
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        num_results: int = 10,
        dialect: int = 2,
    ):
        """
        Instantiates a MultiVectorQuery object.

        Args:
            vectors (Union[Vector, List[Vector]]): The Vectors to perform vector similarity search.
            return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
            filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
                Defaults to None.
            num_results (int, optional): The number of results to return. Defaults to 10.
            dialect (int, optional): The Redis dialect version. Defaults to 2.

        Raises:
            TypeError: If ``vectors`` is not a Vector or an iterable of Vector objects.
            ValueError: If ``vectors`` is empty.
        """

        self._filter_expression = filter_expression
        self._num_results = num_results

        type_error_message = (
            "vector argument must be a Vector object or list of Vector objects."
        )

        if isinstance(vectors, Vector):
            self._vectors = [vectors]
        else:
            # Copy the input so later mutation of the caller's list cannot
            # desynchronize the query string from ``params``.
            try:
                self._vectors = list(vectors)
            except TypeError:
                raise TypeError(type_error_message) from None

        if not all(isinstance(v, Vector) for v in self._vectors):
            raise TypeError(type_error_message)

        # Without vectors the query string and the combined score expression
        # would both be empty, so fail early with a clear message.
        if not self._vectors:
            raise ValueError(
                "At least one Vector is required to build a MultiVectorQuery."
            )

        super().__init__(self._build_query_string())

        # Turn each yielded distance into a similarity score. For a distance in
        # [0, 2] the expression (2 - d)/2 maps it onto [1, 0].
        weighted_scores = []
        for i, v in enumerate(self._vectors):
            self.apply(**{f"score_{i}": f"(2 - @distance_{i})/2"})
            weighted_scores.append(f"@score_{i} * {v.weight}")

        # The final ranking score is the weighted sum of the individual scores.
        self.apply(combined_score=" + ".join(weighted_scores))

        self.sort_by(Desc("@combined_score"), max=num_results)  # type: ignore
        self.dialect(dialect)
        if return_fields:
            self.load(*return_fields)  # type: ignore[arg-type]

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        Each vector is exposed as ``vector_<i>``, matching the ``$vector_<i>``
        placeholders in the query string. List vectors are converted with
        ``array_to_buffer`` using the Vector's dtype; bytes are passed through.

        Returns:
            Dict[str, Any]: The parameters for the aggregation.
        """
        params: Dict[str, Any] = {}
        for i, v in enumerate(self._vectors):
            vector = v.vector
            if isinstance(vector, list):
                vector = array_to_buffer(vector, dtype=v.dtype)  # type: ignore
            params[f"vector_{i}"] = vector
        return params

    def _build_query_string(self) -> str:
        """Build the full query string for the vector range search with optional filtering.

        Returns:
            str: One VECTOR_RANGE clause per vector joined with ``|``, combined
            with the filter expression when one is set.
        """

        # One range clause per vector; each yields its distance as distance_<i>.
        # NOTE(review): the 2.0 radius looks chosen as the maximum cosine
        # distance so that no document is excluded by range - confirm.
        range_queries = [
            f"@{v.field_name}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}"
            for i, v in enumerate(self._vectors)
        ]
        range_query = " | ".join(range_queries)

        filter_expression = self._filter_expression
        if isinstance(filter_expression, FilterExpression):
            filter_expression = str(filter_expression)

        if filter_expression:
            return f"({range_query}) AND ({filter_expression})"
        return range_query

    def __str__(self) -> str:
        """Return the string representation of the query."""
        return " ".join([str(x) for x in self.build_args()])

tests/conftest.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,96 @@ def sample_data(sample_datetimes):
308308
]
309309

310310

311+
@pytest.fixture
def multi_vector_data(sample_datetimes):
    """Return seven sample user records that each carry three embeddings.

    Every record has scalar fields (user, age, job, description, last_updated,
    credit_score, location) plus a 3-value ``user_embedding``, a 5-value
    ``image_embedding`` and a 6-value ``audio_embedding``. ``last_updated`` is
    taken from the "low", "mid" and "high" entries of ``sample_datetimes``.
    """
    low_ts = sample_datetimes["low"].timestamp()
    mid_ts = sample_datetimes["mid"].timestamp()
    high_ts = sample_datetimes["high"].timestamp()

    # Only two distinct locations are used across the records.
    location_a = "-122.4194,37.7749"
    location_b = "-110.0839,37.3861"

    john = dict(
        user="john",
        age=18,
        job="engineer",
        description="engineers conduct trains that ride on train tracks",
        last_updated=low_ts,
        credit_score="high",
        location=location_a,
        user_embedding=[0.1, 0.1, 0.5],
        image_embedding=[0.1, 0.1, 0.1, 0.1, 0.1],
        audio_embedding=[34, 18.5, -6.0, -12, 115, 96.5],
    )
    mary = dict(
        user="mary",
        age=14,
        job="doctor",
        description="a medical professional who treats diseases and helps people stay healthy",
        last_updated=low_ts,
        credit_score="low",
        location=location_a,
        user_embedding=[0.1, 0.1, 0.5],
        image_embedding=[0.1, 0.2, 0.3, 0.4, 0.5],
        audio_embedding=[0.0, -1.06, 4.55, -1.93, 0.0, 1.53],
    )
    nancy = dict(
        user="nancy",
        age=94,
        job="doctor",
        description="a research scientist specializing in cancers and diseases of the lungs",
        last_updated=mid_ts,
        credit_score="high",
        location=location_a,
        user_embedding=[0.7, 0.1, 0.5],
        image_embedding=[0.1, 0.1, 0.3, 0.3, 0.5],
        audio_embedding=[2.75, -0.33, -3.01, -0.52, 5.59, -2.30],
    )
    tyler = dict(
        user="tyler",
        age=100,
        job="engineer",
        description="a software developer with expertise in mathematics and computer science",
        last_updated=mid_ts,
        credit_score="high",
        location=location_b,
        user_embedding=[0.1, 0.4, 0.5],
        image_embedding=[-0.1, -0.2, -0.3, -0.4, -0.5],
        audio_embedding=[1.11, -6.73, 5.41, 1.04, 3.92, 0.73],
    )
    tim = dict(
        user="tim",
        age=12,
        job="dermatologist",
        description="a medical professional specializing in diseases of the skin",
        last_updated=mid_ts,
        credit_score="high",
        location=location_b,
        user_embedding=[0.4, 0.4, 0.5],
        image_embedding=[-0.1, 0.0, 0.6, 0.0, -0.9],
        audio_embedding=[0.03, -2.67, -2.08, 4.57, -2.33, 0.0],
    )
    taimur = dict(
        user="taimur",
        age=15,
        job="CEO",
        description="high stress, but financially rewarding position at the head of a company",
        last_updated=high_ts,
        credit_score="low",
        location=location_b,
        user_embedding=[0.6, 0.1, 0.5],
        image_embedding=[1.1, 1.2, -0.3, -4.1, 5.0],
        audio_embedding=[0.68, 0.26, 2.08, 2.96, 0.01, 5.13],
    )
    joe = dict(
        user="joe",
        age=35,
        job="dentist",
        description="like the tooth fairy because they'll take your teeth, but you have to pay them!",
        last_updated=high_ts,
        credit_score="medium",
        location=location_b,
        user_embedding=[-0.1, -0.1, -0.5],
        image_embedding=[-0.8, 2.0, 3.1, 1.5, -1.6],
        audio_embedding=[0.91, 7.10, -2.14, -0.52, -6.08, -5.53],
    )

    return [john, mary, nancy, tyler, tim, taimur, joe]
399+
400+
311401
def pytest_addoption(parser: pytest.Parser) -> None:
312402
parser.addoption(
313403
"--run-api-tests",

0 commit comments

Comments
 (0)