44from redis .commands .search .result import Result
55
66from redisvl .index import SearchIndex
7- from redisvl .query import CountQuery , FilterQuery , RangeQuery , VectorQuery
7+ from redisvl .query import CountQuery , FilterQuery , VectorQuery , VectorRangeQuery
88from redisvl .query .filter import (
99 FilterExpression ,
1010 Geo ,
@@ -105,7 +105,7 @@ def sorted_filter_query():
105105
106106@pytest .fixture
107107def normalized_range_query ():
108- return RangeQuery (
108+ return VectorRangeQuery (
109109 vector = [0.1 , 0.1 , 0.5 ],
110110 vector_field_name = "user_embedding" ,
111111 normalize_vector_distance = True ,
@@ -117,7 +117,7 @@ def normalized_range_query():
117117
118118@pytest .fixture
119119def range_query ():
120- return RangeQuery (
120+ return VectorRangeQuery (
121121 vector = [0.1 , 0.1 , 0.5 ],
122122 vector_field_name = "user_embedding" ,
123123 return_fields = ["user" , "credit_score" , "age" , "job" , "location" ],
@@ -127,7 +127,7 @@ def range_query():
127127
128128@pytest .fixture
129129def sorted_range_query ():
130- return RangeQuery (
130+ return VectorRangeQuery (
131131 vector = [0.1 , 0.1 , 0.5 ],
132132 vector_field_name = "user_embedding" ,
133133 return_fields = ["user" , "credit_score" , "age" , "job" , "location" ],
@@ -272,7 +272,7 @@ def test_search_and_query(index):
272272
273273
274274def test_range_query (index ):
275- r = RangeQuery (
275+ r = VectorRangeQuery (
276276 vector = [0.1 , 0.1 , 0.5 ],
277277 vector_field_name = "user_embedding" ,
278278 return_fields = ["user" , "credit_score" , "age" , "job" ],
@@ -343,7 +343,7 @@ def search(
343343 assert doc .location == location
344344
345345 # if range query, test results by distance threshold
346- if isinstance (query , RangeQuery ):
346+ if isinstance (query , VectorRangeQuery ):
347347 for doc in results .docs :
348348 print (doc .vector_distance )
349349 assert float (doc .vector_distance ) <= distance_threshold
@@ -354,7 +354,7 @@ def search(
354354
355355 # check results are in sorted order
356356 if sort :
357- if isinstance (query , RangeQuery ):
357+ if isinstance (query , VectorRangeQuery ):
358358 assert [int (doc .age ) for doc in results .docs ] == [12 , 14 , 18 , 100 ]
359359 else :
360360 assert [int (doc .age ) for doc in results .docs ] == [
@@ -370,7 +370,7 @@ def search(
370370
371371@pytest .fixture (
372372 params = ["vector_query" , "filter_query" , "range_query" ],
373- ids = ["VectorQuery" , "FilterQuery" , "RangeQuery " ],
373+ ids = ["VectorQuery" , "FilterQuery" , "VectorRangeQuery " ],
374374)
375375def query (request ):
376376 return request .getfixturevalue (request .param )
@@ -778,3 +778,15 @@ def test_range_query_normalize_cosine_distance(index, normalized_range_query):
778778
779779 for r in res :
780780 assert 0 <= float (r ["vector_distance" ]) <= 1
781+
782+
783+ def test_range_query_normalize_bad_input (index ):
784+ with pytest .raises (ValueError ):
785+ VectorRangeQuery (
786+ vector = [0.1 , 0.1 , 0.5 ],
787+ vector_field_name = "user_embedding" ,
788+ normalize_vector_distance = True ,
789+ return_score = True ,
790+ return_fields = ["user" , "credit_score" , "age" , "job" , "location" ],
791+ distance_threshold = 1.2 ,
792+ )
0 commit comments