1
1
from enum import Enum
2
2
from typing import Any , Dict , List , Optional , Union
3
3
4
+ from redis .commands .search .aggregation import AggregateRequest , Desc
4
5
from redis .commands .search .query import Query as RedisQuery
5
6
6
7
from redisvl .query .filter import FilterExpression
@@ -137,7 +138,7 @@ def __init__(
137
138
"""A query for a simple count operation provided some filter expression.
138
139
139
140
Args:
140
- filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
141
+ filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
141
142
query with. Defaults to None.
142
143
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
143
144
@@ -654,31 +655,32 @@ class RangeQuery(VectorRangeQuery):
654
655
655
656
class TextQuery (FilterQuery ):
656
657
def __init__ (
657
- self ,
658
+ self ,
658
659
text : str ,
659
660
text_field : str ,
660
- text_scorer : str = "TFIDF" ,
661
- return_fields : Optional [List [str ]] = None ,
661
+ text_scorer : str = "BM25" ,
662
662
filter_expression : Optional [Union [str , FilterExpression ]] = None ,
663
+ return_fields : Optional [List [str ]] = None ,
663
664
num_results : int = 10 ,
664
665
return_score : bool = True ,
665
666
dialect : int = 2 ,
666
667
sort_by : Optional [str ] = None ,
667
668
in_order : bool = False ,
669
+ params : Optional [Dict [str , Any ]] = None ,
668
670
):
669
671
"""A query for running a full text and vector search, along with an optional
670
672
filter expression.
671
673
672
674
Args:
673
- text (str): The text string to perform the text search with.
675
+ text (str): The text string to perform the text search with.
674
676
text_field (str): The name of the document field to perform text search on.
675
677
text_scorer (str, optional): The text scoring algorithm to use.
676
- Defaults to TFIDF . Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
678
+ Defaults to BM25 . Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
677
679
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
680
+ filter_expression (Union[str, FilterExpression], optional): A filter to apply
681
+ along with the text search. Defaults to None.
678
682
return_fields (List[str]): The declared fields to return with search
679
683
results.
680
- filter_expression (Union[str, FilterExpression], optional): A filter to apply
681
- along with the vector search. Defaults to None.
682
684
num_results (int, optional): The top k results to return from the
683
685
search. Defaults to 10.
684
686
return_score (bool, optional): Whether to return the text score.
@@ -690,174 +692,82 @@ def __init__(
690
692
in_order (bool): Requires the terms in the field to have
691
693
the same order as the terms in the query filter, regardless of
692
694
the offsets between them. Defaults to False.
693
-
694
- Raises:
695
- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
695
+ params (Optional[Dict[str, Any]], optional): The parameters for the query.
696
+ Defaults to None.
696
697
"""
698
+ import nltk
699
+ from nltk .corpus import stopwords
700
+
701
+ nltk .download ("stopwords" )
702
+ self ._stopwords = set (stopwords .words ("english" ))
703
+
704
+ self ._text = text
697
705
self ._text_field = text_field
698
- self ._num_results = num_results
706
+ self ._text_scorer = text_scorer
707
+
699
708
self .set_filter (filter_expression )
700
- query_string = self ._build_query_string ()
701
- from nltk .corpus import stopwords
702
- import nltk
709
+ self ._num_results = num_results
703
710
704
- nltk .download ('stopwords' )
705
- self ._stopwords = set (stopwords .words ('english' ))
711
+ query_string = self ._build_query_string ()
706
712
707
- super ().__init__ (query_string )
713
+ super ().__init__ (
714
+ query_string ,
715
+ return_fields = return_fields ,
716
+ num_results = num_results ,
717
+ dialect = dialect ,
718
+ sort_by = sort_by ,
719
+ in_order = in_order ,
720
+ params = params ,
721
+ )
708
722
709
723
# Handle query modifiers
710
- if return_fields :
711
- self .return_fields (* return_fields )
712
-
724
+ self .scorer (self ._text_scorer )
713
725
self .paging (0 , self ._num_results ).dialect (dialect )
714
726
715
727
if return_score :
716
- self .return_fields (self .DISTANCE_ID ) #TODO
717
-
718
- if sort_by :
719
- self .sort_by (sort_by )
720
- else :
721
- self .sort_by (self .DISTANCE_ID ) #TODO
728
+ self .with_scores ()
722
729
723
- if in_order :
724
- self .in_order ()
725
-
726
-
727
- def _tokenize_query (self , user_query : str ) -> str :
730
+ def tokenize_and_escape_query (self , user_query : str ) -> str :
728
731
"""Convert a raw user query to a redis full text query joined by ORs"""
732
+ from redisvl .utils .token_escaper import TokenEscaper
729
733
730
- words = word_tokenize (user_query )
731
-
732
- tokens = [token .strip ().strip ("," ).lower () for token in user_query .split ()]
733
- return " | " .join ([token for token in tokens if token not in self ._stopwords ])
734
+ escaper = TokenEscaper ()
734
735
736
+ tokens = [
737
+ escaper .escape (
738
+ token .strip ().strip ("," ).replace ("“" , "" ).replace ("”" , "" ).lower ()
739
+ )
740
+ for token in user_query .split ()
741
+ ]
742
+ return " | " .join (
743
+ [token for token in tokens if token and token not in self ._stopwords ]
744
+ )
735
745
736
746
def _build_query_string (self ) -> str :
737
747
"""Build the full query string for text search with optional filtering."""
738
748
filter_expression = self ._filter_expression
739
- # TODO include text only
740
749
if isinstance (filter_expression , FilterExpression ):
741
750
filter_expression = str (filter_expression )
751
+ else :
752
+ filter_expression = ""
742
753
743
- text = f"(~{ Text (self ._text_field ) % self ._tokenize_query (user_query )} )"
744
-
745
- text_and_filter = text & self ._filter_expression
746
-
747
- #TODO is this method even needed? use
748
- return text_and_filter
754
+ text = f"(~@{ self ._text_field } :({ self .tokenize_and_escape_query (self ._text )} ))"
755
+ if filter_expression and filter_expression != "*" :
756
+ text += f"({ filter_expression } )"
757
+ return text
749
758
750
- # from redisvl.utils.token_escaper import TokenEscaper
751
- # escaper = TokenEscaper()
752
- # def tokenize_and_escape_query(user_query: str) -> str:
753
- # """Convert a raw user query to a redis full text query joined by ORs"""
754
- # tokens = [escaper.escape(token.strip().strip(",").replace("“", "").replace("”", "").lower()) for token in user_query.split()]
755
- # return " | ".join([token for token in tokens if token and token not in stopwords_en])
756
759
757
- class HybridQuery (VectorQuery , TextQuery ):
758
- def __init__ ():
759
- self ,
760
- text : str ,
761
- text_field : str ,
762
- vector : Union [List [float ], bytes ],
763
- vector_field_name : str ,
764
- text_scorer : str = "TFIDF" ,
765
- alpha : float = 0.7 ,
766
- return_fields : Optional [List [str ]] = None ,
767
- filter_expression : Optional [Union [str , FilterExpression ]] = None ,
768
- dtype : str = "float32" ,
769
- num_results : int = 10 ,
770
- return_score : bool = True ,
771
- dialect : int = 2 ,
772
- sort_by : Optional [str ] = None ,
773
- in_order : bool = False ,
760
+ class HybridQuery (AggregateRequest ):
761
+ def __init__ (
762
+ self , text_query : TextQuery , vector_query : VectorQuery , alpha : float = 0.7
774
763
):
775
- """A query for running a hybrid full text and vector search, along with
776
- an optional filter expression.
764
+ """An aggregate query for running a hybrid full text and vector search.
777
765
778
766
Args:
779
- text (str): The text string to run text search with.
780
- text_field (str): The name of the text field to search against.
781
- vector (List[float]): The vector to perform the vector search with.
782
- vector_field_name (str): The name of the vector field to search
783
- against in the database.
784
- text_scorer (str, optional): The text scoring algorithm to use.
785
- Defaults to TFIDF.
767
+ text_query (TextQuery): The text query to run text search with.
768
+ vector_query (VectorQuery): The vector query to run vector search with.
786
769
alpha (float, optional): The amount to weight the vector similarity
787
770
score relative to the text similarity score. Defaults to 0.7
788
- return_fields (List[str]): The declared fields to return with search
789
- results.
790
- filter_expression (Union[str, FilterExpression], optional): A filter to apply
791
- along with the vector search. Defaults to None.
792
- dtype (str, optional): The dtype of the vector. Defaults to
793
- "float32".
794
- num_results (int, optional): The top k results to return from the
795
- vector search. Defaults to 10.
796
- return_score (bool, optional): Whether to return the vector
797
- distance. Defaults to True.
798
- dialect (int, optional): The RediSearch query dialect.
799
- Defaults to 2.
800
- sort_by (Optional[str]): The field to order the results by. Defaults
801
- to None. Results will be ordered by vector distance.
802
- in_order (bool): Requires the terms in the field to have
803
- the same order as the terms in the query filter, regardless of
804
- the offsets between them. Defaults to False.
805
-
806
- Raises:
807
- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
808
-
809
- Note:
810
- Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
811
- """
812
- self ._text = text
813
- self ._text_field_name = tex_field_name
814
- self ._vector = vector
815
- self ._vector_field_name = vector_field_name
816
- self ._dtype = dtype
817
- self ._num_results = num_results
818
- self .set_filter (filter_expression )
819
- query_string = self ._build_query_string ()
820
-
821
- # TODO how to handle multiple parents? call parent.__init__() manually?
822
- super ().__init__ (query_string )
823
-
824
- # Handle query modifiers
825
- if return_fields :
826
- self .return_fields (* return_fields )
827
771
828
- self .paging (0 , self ._num_results ).dialect (dialect )
829
-
830
- if return_score :
831
- self .return_fields (self .DISTANCE_ID )
832
-
833
- if sort_by :
834
- self .sort_by (sort_by )
835
- else :
836
- self .sort_by (self .DISTANCE_ID )
837
-
838
- if in_order :
839
- self .in_order ()
840
-
841
-
842
- def _build_query_string (self ) -> str :
843
- """Build the full query string for hybrid search with optional filtering."""
844
- filter_expression = self ._filter_expression
845
- # TODO include hybrid
846
- if isinstance (filter_expression , FilterExpression ):
847
- filter_expression = str (filter_expression )
848
- return f"{ filter_expression } =>[KNN { self ._num_results } @{ self ._vector_field_name } ${ self .VECTOR_PARAM } AS { self .DISTANCE_ID } ]"
849
-
850
- @property
851
- def params (self ) -> Dict [str , Any ]:
852
- """Return the parameters for the query.
853
-
854
- Returns:
855
- Dict[str, Any]: The parameters for the query.
856
772
"""
857
- if isinstance (self ._vector , bytes ):
858
- vector = self ._vector
859
- else :
860
- vector = array_to_buffer (self ._vector , dtype = self ._dtype )
861
-
862
- return {self .VECTOR_PARAM : vector }
863
-
773
+ pass
0 commit comments