Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions redisvl/query/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, model_validator
from redis.commands.search.aggregation import AggregateRequest, Desc
from typing_extensions import Self

from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.redis.utils import array_to_buffer, buffer_to_array
from redisvl.schema.fields import VectorDataType
from redisvl.utils.token_escaper import TokenEscaper
from redisvl.utils.utils import lazy_import
Expand Down Expand Up @@ -32,9 +33,16 @@ def validate_dtype(cls, dtype: str) -> str:
raise ValueError(
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
)

return dtype

@model_validator(mode="after")
def validate_vector(self) -> Self:
    """Normalize ``vector`` to a byte string after model validation.

    A value that is already ``bytes`` is left untouched; any other value is
    passed to ``array_to_buffer`` together with ``dtype`` and the result is
    stored back on the instance.

    Returns:
        Self: this same instance.
    """
    already_encoded = isinstance(self.vector, bytes)
    if not already_encoded:
        self.vector = array_to_buffer(self.vector, self.dtype)
    return self


class AggregationQuery(AggregateRequest):
"""
Expand Down Expand Up @@ -364,12 +372,8 @@ def params(self) -> Dict[str, Any]:
Dict[str, Any]: The parameters for the aggregation.
"""
params = {}
for i, (vector, dtype) in enumerate(
[(v.vector, v.dtype) for v in self._vectors]
):
if isinstance(vector, list):
vector = array_to_buffer(vector, dtype=dtype) # type: ignore
params[f"vector_{i}"] = vector
for i, v in enumerate(self._vectors):
params[f"vector_{i}"] = v.vector
return params

def _build_query_string(self) -> str:
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,32 @@ def test_multivector_query(index):
)


def test_multivector_query_accepts_bytes(index):
    """Run a MultiVectorQuery whose vectors are supplied as pre-encoded bytes.

    Builds one ``Vector`` per field from byte strings produced by
    ``array_to_buffer`` and checks that querying the index returns a list
    of 7 results.
    """
    skip_if_redis_version_below(index.client, "7.2.0")

    vector_bytes = [
        array_to_buffer([0.1, 0.1, 0.5], "float32"),
        array_to_buffer([0.3, 0.4, 0.7, 0.2, -0.3, 0.25], "float64"),
    ]
    vector_fields = ["user_embedding", "audio_embedding"]
    # NOTE(review): these dtypes differ from the encodings used for the
    # buffers above (float32 / float64) -- confirm this is intentional.
    # A previous, immediately-overwritten assignment of
    # ["float32", "float64"] was removed as a dead store.
    dtypes = ["float16", "float16"]
    vectors = [
        Vector(vector=vector, field_name=field, dtype=dtype)
        for vector, field, dtype in zip(vector_bytes, vector_fields, dtypes)
    ]

    return_fields = ["user", "credit_score", "age", "job", "location", "description"]

    multi_query = MultiVectorQuery(
        vectors=vectors,
        return_fields=return_fields,
    )

    results = index.query(multi_query)
    assert isinstance(results, list)
    assert len(results) == 7


def test_multivector_query_with_filter(index):
skip_if_redis_version_below(index.client, "7.2.0")

Expand Down
18 changes: 18 additions & 0 deletions tests/unit/test_aggregation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from redisvl.index.index import process_results
from redisvl.query.aggregate import HybridQuery, MultiVectorQuery, Vector
from redisvl.query.filter import Tag
from redisvl.redis.utils import array_to_buffer

# Sample data for testing
sample_vector = [0.1, 0.2, 0.3, 0.4]
Expand Down Expand Up @@ -314,3 +315,20 @@ def test_vector_object_validation():
for dtype in ["bfloat16", "float16", "float32", "float64", "int8", "uint8"]:
vec = Vector(vector=sample_vector, field_name="text embedding", dtype=dtype)
assert isinstance(vec, Vector)


def test_vector_object_handles_byte_conversion():
    """Vector stores float lists as encoded bytes and keeps bytes as given."""
    float_dtypes = ("bfloat16", "float16", "float32", "float64")

    # a list of floats ends up stored as its byte-string encoding
    from_floats = Vector(vector=sample_vector, field_name="field 1", dtype="float16")
    expected = array_to_buffer(sample_vector, dtype="float16")
    assert from_floats.vector == expected

    # the same holds for each of the floating point dtypes listed above
    for name in float_dtypes:
        converted = Vector(vector=sample_vector, field_name="field 1", dtype=name)
        assert converted.vector == array_to_buffer(sample_vector, dtype=name)

    # a byte string passed in directly is stored unchanged
    for name in float_dtypes:
        raw = array_to_buffer(sample_vector, name)
        assert Vector(vector=raw, field_name="field 1").vector == raw