|  | 
| 1 | 1 | from typing import Any, Dict, List, Optional, Set, Tuple, Union | 
| 2 | 2 | 
 | 
|  | 3 | +from pydantic import BaseModel, field_validator | 
| 3 | 4 | from redis.commands.search.aggregation import AggregateRequest, Desc | 
| 4 | 5 | 
 | 
| 5 | 6 | from redisvl.query.filter import FilterExpression | 
| 6 | 7 | from redisvl.redis.utils import array_to_buffer | 
|  | 8 | +from redisvl.schema.fields import VectorDataType | 
| 7 | 9 | from redisvl.utils.token_escaper import TokenEscaper | 
| 8 | 10 | from redisvl.utils.utils import lazy_import | 
| 9 | 11 | 
 | 
| 10 | 12 | nltk = lazy_import("nltk") | 
| 11 | 13 | nltk_stopwords = lazy_import("nltk.corpus.stopwords") | 
| 12 | 14 | 
 | 
| 13 | 15 | 
 | 
class Vector(BaseModel):
    """
    Simple object containing the necessary arguments to perform a multi vector query.
    """

    # Query embedding: either a list of floats or an already-serialized byte buffer.
    vector: Union[List[float], bytes]
    # Name of the vector field this embedding is searched against.
    field_name: str
    # Datatype name; checked (case-insensitively) against VectorDataType below.
    dtype: str = "float32"
    # Multiplier applied to this vector's score when scores are combined.
    weight: float = 1.0

    @field_validator("dtype")
    @classmethod
    def validate_dtype(cls, dtype: str) -> str:
        """Check that ``dtype`` names a member of ``VectorDataType``.

        The lookup is done on the upper-cased name, so any casing is accepted;
        the value is returned exactly as it was supplied.

        Raises:
            ValueError: If the upper-cased name is not a valid VectorDataType.
        """
        try:
            VectorDataType(dtype.upper())
        except ValueError:
            supported = [t.lower() for t in VectorDataType]
            raise ValueError(
                f"Invalid data type: {dtype}. Supported types are: {supported}"
            )
        else:
            return dtype
|  | 37 | + | 
|  | 38 | + | 
| 14 | 39 | class AggregationQuery(AggregateRequest): | 
| 15 | 40 |     """ | 
| 16 | 41 |     Base class for aggregation queries used to create aggregation queries for Redis. | 
| @@ -227,3 +252,149 @@ def _build_query_string(self) -> str: | 
| 227 | 252 |     def __str__(self) -> str: | 
| 228 | 253 |         """Return the string representation of the query.""" | 
| 229 | 254 |         return " ".join([str(x) for x in self.build_args()]) | 
|  | 255 | + | 
|  | 256 | + | 
class MultiVectorQuery(AggregationQuery):
    """
    MultiVectorQuery allows for search over multiple vector fields in a document simultaneously.
    The final score will be a weighted combination of the individual vector similarity scores
    following the formula:

    score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... )

    Vectors may be of different size and datatype, but must be indexed using the 'cosine' distance_metric.

    .. code-block:: python

        from redisvl.query import MultiVectorQuery, Vector
        from redisvl.index import SearchIndex

        index = SearchIndex.from_yaml("path/to/index.yaml")

        vector_1 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float32",
            weight=0.7,
        )
        vector_2 = Vector(
            vector=[0.5, 0.5],
            field_name="image_vector",
            dtype="bfloat16",
            weight=0.2,
        )
        vector_3 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float64",
            weight=0.5,
        )

        query = MultiVectorQuery(
            vectors=[vector_1, vector_2, vector_3],
            filter_expression=None,
            num_results=10,
            return_fields=["field1", "field2"],
            dialect=2,
        )

        results = index.query(query)
    """

    _vectors: List[Vector]

    def __init__(
        self,
        vectors: Union[Vector, List[Vector]],
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        num_results: int = 10,
        dialect: int = 2,
    ):
        """
        Instantiates a MultiVectorQuery object.

        Args:
            vectors (Union[Vector, List[Vector]]): The Vectors to perform vector similarity search.
            return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
            filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
                Defaults to None.
            num_results (int, optional): The number of results to return. Defaults to 10.
            dialect (int, optional): The Redis dialect version. Defaults to 2.

        Raises:
            TypeError: If vectors is not a Vector or an iterable containing only Vector objects.
            ValueError: If vectors is empty.
        """
        type_error_message = (
            "vector argument must be a Vector object or list of Vector objects."
        )

        self._filter_expression = filter_expression
        self._num_results = num_results

        if isinstance(vectors, Vector):
            self._vectors = [vectors]
        else:
            try:
                # Copy into a private list: a one-shot iterable would otherwise be
                # exhausted by the type check below, and a caller-owned list could
                # be mutated after the query string has already been built.
                self._vectors = list(vectors)
            except TypeError:
                raise TypeError(type_error_message) from None

        if not all(isinstance(v, Vector) for v in self._vectors):
            raise TypeError(type_error_message)

        # all() is vacuously true for an empty list; without this guard the
        # query string and the combined_score expression would both be empty.
        if not self._vectors:
            raise ValueError(
                "At least one Vector is required to build a MultiVectorQuery."
            )

        query_string = self._build_query_string()
        super().__init__(query_string)

        # calculate the respective vector similarities: maps a distance in
        # [0, 2] onto a score in [1, 0]
        for i, _ in enumerate(self._vectors):
            self.apply(**{f"score_{i}": f"(2 - @distance_{i})/2"})

        # construct the scoring string based on the vector similarity scores and weights
        combined_score_string = " + ".join(
            f"@score_{i} * {v.weight}" for i, v in enumerate(self._vectors)
        )
        self.apply(combined_score=combined_score_string)

        self.sort_by(Desc("@combined_score"), max=num_results)  # type: ignore
        self.dialect(dialect)
        if return_fields:
            self.load(*return_fields)  # type: ignore[arg-type]

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        Each vector is exposed as ``vector_<i>``, matching the ``$vector_<i>``
        placeholders in the query string. Vectors given as lists are converted
        with ``array_to_buffer`` using the vector's dtype; other values are
        passed through unchanged.

        Returns:
            Dict[str, Any]: The parameters for the aggregation.
        """
        params: Dict[str, Any] = {}
        for i, v in enumerate(self._vectors):
            vector = v.vector
            if isinstance(vector, list):
                vector = array_to_buffer(vector, dtype=v.dtype)  # type: ignore
            params[f"vector_{i}"] = vector
        return params

    def _build_query_string(self) -> str:
        """Build the full query string for the vector range search with optional filtering.

        One ``VECTOR_RANGE`` clause is emitted per vector, each yielding its
        distance as ``distance_<i>``; the clauses are joined with ``|``. When a
        filter expression is set, it is appended to the parenthesized range
        clauses.

        Returns:
            str: The query string passed to the aggregation request.
        """
        # one range clause per vector, all using the fixed radius 2.0
        range_queries = [
            f"@{v.field_name}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}"
            for i, v in enumerate(self._vectors)
        ]
        range_query = " | ".join(range_queries)

        filter_expression = self._filter_expression
        if isinstance(self._filter_expression, FilterExpression):
            filter_expression = str(self._filter_expression)

        if filter_expression:
            return f"({range_query}) AND ({filter_expression})"
        return f"{range_query}"

    def __str__(self) -> str:
        """Return the string representation of the query."""
        return " ".join([str(x) for x in self.build_args()])
0 commit comments