Skip to content

Commit 3b8e2b6

Browse files
justin-cechmanek authored and tylerhutcherson committed
adds TextQuery class
1 parent 3fa93ff commit 3b8e2b6

File tree

3 files changed

+134
-150
lines changed

3 files changed

+134
-150
lines changed

redisvl/query/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,9 @@
22
BaseQuery,
33
CountQuery,
44
FilterQuery,
5+
HybridQuery,
56
RangeQuery,
7+
TextQuery,
68
VectorQuery,
79
VectorRangeQuery,
810
)
@@ -14,4 +16,6 @@
1416
"RangeQuery",
1517
"VectorRangeQuery",
1618
"CountQuery",
19+
"TextQuery",
20+
"HybridQuery",
1721
]

redisvl/query/query.py

Lines changed: 59 additions & 149 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from enum import Enum
22
from typing import Any, Dict, List, Optional, Union
33

4+
from redis.commands.search.aggregation import AggregateRequest, Desc
45
from redis.commands.search.query import Query as RedisQuery
56

67
from redisvl.query.filter import FilterExpression
@@ -137,7 +138,7 @@ def __init__(
137138
"""A query for a simple count operation provided some filter expression.
138139
139140
Args:
140-
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
141+
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
141142
query with. Defaults to None.
142143
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
143144
@@ -654,31 +655,32 @@ class RangeQuery(VectorRangeQuery):
654655

655656
class TextQuery(FilterQuery):
656657
def __init__(
657-
self,
658+
self,
658659
text: str,
659660
text_field: str,
660-
text_scorer: str = "TFIDF",
661-
return_fields: Optional[List[str]] = None,
661+
text_scorer: str = "BM25",
662662
filter_expression: Optional[Union[str, FilterExpression]] = None,
663+
return_fields: Optional[List[str]] = None,
663664
num_results: int = 10,
664665
return_score: bool = True,
665666
dialect: int = 2,
666667
sort_by: Optional[str] = None,
667668
in_order: bool = False,
669+
params: Optional[Dict[str, Any]] = None,
668670
):
669671
"""A query for running a full text and vector search, along with an optional
670672
filter expression.
671673
672674
Args:
673-
text (str): The text string to perform the text search with.
675+
text (str): The text string to perform the text search with.
674676
text_field (str): The name of the document field to perform text search on.
675677
text_scorer (str, optional): The text scoring algorithm to use.
676-
Defaults to TFIDF. Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
678+
Defaults to BM25. Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
677679
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
680+
filter_expression (Union[str, FilterExpression], optional): A filter to apply
681+
along with the text search. Defaults to None.
678682
return_fields (List[str]): The declared fields to return with search
679683
results.
680-
filter_expression (Union[str, FilterExpression], optional): A filter to apply
681-
along with the vector search. Defaults to None.
682684
num_results (int, optional): The top k results to return from the
683685
search. Defaults to 10.
684686
return_score (bool, optional): Whether to return the text score.
@@ -690,174 +692,82 @@ def __init__(
690692
in_order (bool): Requires the terms in the field to have
691693
the same order as the terms in the query filter, regardless of
692694
the offsets between them. Defaults to False.
693-
694-
Raises:
695-
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
695+
params (Optional[Dict[str, Any]], optional): The parameters for the query.
696+
Defaults to None.
696697
"""
698+
import nltk
699+
from nltk.corpus import stopwords
700+
701+
nltk.download("stopwords")
702+
self._stopwords = set(stopwords.words("english"))
703+
704+
self._text = text
697705
self._text_field = text_field
698-
self._num_results = num_results
706+
self._text_scorer = text_scorer
707+
699708
self.set_filter(filter_expression)
700-
query_string = self._build_query_string()
701-
from nltk.corpus import stopwords
702-
import nltk
709+
self._num_results = num_results
703710

704-
nltk.download('stopwords')
705-
self._stopwords = set(stopwords.words('english'))
711+
query_string = self._build_query_string()
706712

707-
super().__init__(query_string)
713+
super().__init__(
714+
query_string,
715+
return_fields=return_fields,
716+
num_results=num_results,
717+
dialect=dialect,
718+
sort_by=sort_by,
719+
in_order=in_order,
720+
params=params,
721+
)
708722

709723
# Handle query modifiers
710-
if return_fields:
711-
self.return_fields(*return_fields)
712-
724+
self.scorer(self._text_scorer)
713725
self.paging(0, self._num_results).dialect(dialect)
714726

715727
if return_score:
716-
self.return_fields(self.DISTANCE_ID) #TODO
717-
718-
if sort_by:
719-
self.sort_by(sort_by)
720-
else:
721-
self.sort_by(self.DISTANCE_ID) #TODO
728+
self.with_scores()
722729

723-
if in_order:
724-
self.in_order()
725-
726-
727-
def _tokenize_query(self, user_query: str) -> str:
730+
def tokenize_and_escape_query(self, user_query: str) -> str:
728731
"""Convert a raw user query to a redis full text query joined by ORs"""
732+
from redisvl.utils.token_escaper import TokenEscaper
729733

730-
words = word_tokenize(user_query)
731-
732-
tokens = [token.strip().strip(",").lower() for token in user_query.split()]
733-
return " | ".join([token for token in tokens if token not in self._stopwords])
734+
escaper = TokenEscaper()
734735

736+
tokens = [
737+
escaper.escape(
738+
token.strip().strip(",").replace("“", "").replace("”", "").lower()
739+
)
740+
for token in user_query.split()
741+
]
742+
return " | ".join(
743+
[token for token in tokens if token and token not in self._stopwords]
744+
)
735745

736746
def _build_query_string(self) -> str:
737747
"""Build the full query string for text search with optional filtering."""
738748
filter_expression = self._filter_expression
739-
# TODO include text only
740749
if isinstance(filter_expression, FilterExpression):
741750
filter_expression = str(filter_expression)
751+
else:
752+
filter_expression = ""
742753

743-
text = f"(~{Text(self._text_field) % self._tokenize_query(user_query)})"
744-
745-
text_and_filter = text & self._filter_expression
746-
747-
#TODO is this method even needed? use
748-
return text_and_filter
754+
text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"
755+
if filter_expression and filter_expression != "*":
756+
text += f"({filter_expression})"
757+
return text
749758

750-
# from redisvl.utils.token_escaper import TokenEscaper
751-
# escaper = TokenEscaper()
752-
# def tokenize_and_escape_query(user_query: str) -> str:
753-
# """Convert a raw user query to a redis full text query joined by ORs"""
754-
# tokens = [escaper.escape(token.strip().strip(",").replace("“", "").replace("”", "").lower()) for token in user_query.split()]
755-
# return " | ".join([token for token in tokens if token and token not in stopwords_en])
756759

757-
class HybridQuery(VectorQuery, TextQuery):
758-
def __init__():
759-
self,
760-
text: str,
761-
text_field: str,
762-
vector: Union[List[float], bytes],
763-
vector_field_name: str,
764-
text_scorer: str = "TFIDF",
765-
alpha: float = 0.7,
766-
return_fields: Optional[List[str]] = None,
767-
filter_expression: Optional[Union[str, FilterExpression]] = None,
768-
dtype: str = "float32",
769-
num_results: int = 10,
770-
return_score: bool = True,
771-
dialect: int = 2,
772-
sort_by: Optional[str] = None,
773-
in_order: bool = False,
760+
class HybridQuery(AggregateRequest):
761+
def __init__(
762+
self, text_query: TextQuery, vector_query: VectorQuery, alpha: float = 0.7
774763
):
775-
"""A query for running a hybrid full text and vector search, along with
776-
an optional filter expression.
764+
"""An aggregate query for running a hybrid full text and vector search.
777765
778766
Args:
779-
text (str): The text string to run text search with.
780-
text_field (str): The name of the text field to search against.
781-
vector (List[float]): The vector to perform the vector search with.
782-
vector_field_name (str): The name of the vector field to search
783-
against in the database.
784-
text_scorer (str, optional): The text scoring algorithm to use.
785-
Defaults to TFIDF.
767+
text_query (TextQuery): The text query to run text search with.
768+
vector_query (VectorQuery): The vector query to run vector search with.
786769
alpha (float, optional): The amount to weight the vector similarity
787770
score relative to the text similarity score. Defaults to 0.7
788-
return_fields (List[str]): The declared fields to return with search
789-
results.
790-
filter_expression (Union[str, FilterExpression], optional): A filter to apply
791-
along with the vector search. Defaults to None.
792-
dtype (str, optional): The dtype of the vector. Defaults to
793-
"float32".
794-
num_results (int, optional): The top k results to return from the
795-
vector search. Defaults to 10.
796-
return_score (bool, optional): Whether to return the vector
797-
distance. Defaults to True.
798-
dialect (int, optional): The RediSearch query dialect.
799-
Defaults to 2.
800-
sort_by (Optional[str]): The field to order the results by. Defaults
801-
to None. Results will be ordered by vector distance.
802-
in_order (bool): Requires the terms in the field to have
803-
the same order as the terms in the query filter, regardless of
804-
the offsets between them. Defaults to False.
805-
806-
Raises:
807-
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
808-
809-
Note:
810-
Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
811-
"""
812-
self._text = text
813-
self._text_field_name = tex_field_name
814-
self._vector = vector
815-
self._vector_field_name = vector_field_name
816-
self._dtype = dtype
817-
self._num_results = num_results
818-
self.set_filter(filter_expression)
819-
query_string = self._build_query_string()
820-
821-
# TODO how to handle multiple parents? call parent.__init__() manually?
822-
super().__init__(query_string)
823-
824-
# Handle query modifiers
825-
if return_fields:
826-
self.return_fields(*return_fields)
827771
828-
self.paging(0, self._num_results).dialect(dialect)
829-
830-
if return_score:
831-
self.return_fields(self.DISTANCE_ID)
832-
833-
if sort_by:
834-
self.sort_by(sort_by)
835-
else:
836-
self.sort_by(self.DISTANCE_ID)
837-
838-
if in_order:
839-
self.in_order()
840-
841-
842-
def _build_query_string(self) -> str:
843-
"""Build the full query string for hybrid search with optional filtering."""
844-
filter_expression = self._filter_expression
845-
# TODO include hybrid
846-
if isinstance(filter_expression, FilterExpression):
847-
filter_expression = str(filter_expression)
848-
return f"{filter_expression}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]"
849-
850-
@property
851-
def params(self) -> Dict[str, Any]:
852-
"""Return the parameters for the query.
853-
854-
Returns:
855-
Dict[str, Any]: The parameters for the query.
856772
"""
857-
if isinstance(self._vector, bytes):
858-
vector = self._vector
859-
else:
860-
vector = array_to_buffer(self._vector, dtype=self._dtype)
861-
862-
return {self.VECTOR_PARAM: vector}
863-
773+
pass

tests/unit/test_query_types.py

Lines changed: 71 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,14 @@
33
from redis.commands.search.result import Result
44

55
from redisvl.index.index import process_results
6-
from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
6+
from redisvl.query import (
7+
CountQuery,
8+
FilterQuery,
9+
HybridQuery,
10+
RangeQuery,
11+
TextQuery,
12+
VectorQuery,
13+
)
714
from redisvl.query.filter import Tag
815
from redisvl.query.query import VectorRangeQuery
916

@@ -188,6 +195,69 @@ def test_range_query():
188195
assert range_query._in_order
189196

190197

198+
def test_text_query():
199+
text_string = "the toon squad play basketball against a gang of aliens"
200+
text_field_name = "description"
201+
return_fields = ["title", "genre", "rating"]
202+
text_query = TextQuery(
203+
text=text_string,
204+
text_field=text_field_name,
205+
return_fields=return_fields,
206+
return_score=False,
207+
)
208+
209+
# Check properties
210+
assert text_query._return_fields == return_fields
211+
assert text_query._num_results == 10
212+
assert (
213+
text_query.filter
214+
== f"(~@{text_field_name}:({text_query.tokenize_and_escape_query(text_string)}))"
215+
)
216+
assert isinstance(text_query, Query)
217+
assert isinstance(text_query.query, Query)
218+
assert isinstance(text_query.params, dict)
219+
assert text_query._text_scorer == "BM25"
220+
assert text_query.params == {}
221+
assert text_query._dialect == 2
222+
assert text_query._in_order == False
223+
224+
# Test paging functionality
225+
text_query.paging(5, 7)
226+
assert text_query._offset == 5
227+
assert text_query._num == 7
228+
assert text_query._num_results == 10
229+
230+
# Test sort_by functionality
231+
filter_expression = Tag("genre") == "comedy"
232+
scorer = "TFIDF"
233+
text_query = TextQuery(
234+
text_string,
235+
text_field_name,
236+
scorer,
237+
filter_expression,
238+
return_fields,
239+
num_results=10,
240+
sort_by="rating",
241+
)
242+
assert text_query._sortby is not None
243+
244+
# Test in_order functionality
245+
text_query = TextQuery(
246+
text_string,
247+
text_field_name,
248+
scorer,
249+
filter_expression,
250+
return_fields,
251+
num_results=10,
252+
in_order=True,
253+
)
254+
assert text_query._in_order
255+
256+
257+
def test_hybrid_query():
258+
pass
259+
260+
191261
@pytest.mark.parametrize(
192262
"query",
193263
[

0 commit comments

Comments
 (0)