Skip to content

Commit 077cd8a

Browse files
committed
add tests edit bounds
1 parent 4453314 commit 077cd8a

File tree

7 files changed

+56
-27
lines changed

7 files changed

+56
-27
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,9 @@ def set_threshold(self, distance_threshold: float) -> None:
238238
Raises:
239239
ValueError: If the threshold is not between 0 and 1.
240240
"""
241-
if not 0 <= float(distance_threshold) <= 1:
241+
if not 0 <= float(distance_threshold) <= 2:
242242
raise ValueError(
243-
f"Distance must be between 0 and 1, got {distance_threshold}"
243+
f"Distance must be between 0 and 2, got {distance_threshold}"
244244
)
245245
self._distance_threshold = float(distance_threshold)
246246

redisvl/index/index.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@
1616
Union,
1717
)
1818

19-
from redisvl.utils.utils import (
20-
deprecated_argument,
21-
deprecated_function,
22-
norm_cosine_distance,
23-
sync_wrapper,
24-
)
19+
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
2520

2621
if TYPE_CHECKING:
2722
from redis.commands.search.aggregation import AggregateResult
@@ -35,13 +30,7 @@
3530

3631
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
3732
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
38-
from redisvl.query import (
39-
BaseQuery,
40-
CountQuery,
41-
FilterQuery,
42-
VectorQuery,
43-
VectorRangeQuery,
44-
)
33+
from redisvl.query import BaseQuery, BaseVectorQuery, CountQuery, FilterQuery
4534
from redisvl.query.filter import FilterExpression
4635
from redisvl.redis.connection import (
4736
RedisConnectionFactory,
@@ -92,9 +81,7 @@ def process_results(
9281
and not query._return_fields # type: ignore
9382
)
9483

95-
if (
96-
isinstance(query, VectorQuery) or isinstance(query, VectorRangeQuery)
97-
) and query._normalize_vector_distance:
84+
if (isinstance(query, BaseVectorQuery)) and query._normalize_vector_distance:
9885
dist_metric = VectorDistanceMetric(
9986
schema.fields[query._vector_field_name].attrs.distance_metric.upper() # type: ignore
10087
)

redisvl/query/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from redisvl.query.query import (
22
BaseQuery,
3+
BaseVectorQuery,
34
CountQuery,
45
FilterQuery,
56
RangeQuery,
@@ -9,6 +10,7 @@
910

1011
__all__ = [
1112
"BaseQuery",
13+
"BaseVectorQuery",
1214
"VectorQuery",
1315
"FilterQuery",
1416
"RangeQuery",

redisvl/query/query.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from redisvl.query.filter import FilterExpression
77
from redisvl.redis.utils import array_to_buffer
8+
from redisvl.utils.utils import denorm_cosine_distance
89

910

1011
class BaseQuery(RedisQuery):
@@ -175,6 +176,8 @@ class BaseVectorQuery:
175176
DISTANCE_ID: str = "vector_distance"
176177
VECTOR_PARAM: str = "vector"
177178

179+
_normalize_vector_distance: bool = False
180+
178181

179182
class HybridPolicy(str, Enum):
180183
"""Enum for valid hybrid policy options in vector queries."""
@@ -516,6 +519,14 @@ def set_distance_threshold(self, distance_threshold: float):
516519
raise TypeError("distance_threshold must be of type float or int")
517520
if distance_threshold < 0:
518521
raise ValueError("distance_threshold must be non-negative")
522+
if self._normalize_vector_distance:
523+
if distance_threshold > 1:
524+
raise ValueError(
525+
"distance_threshold must be between 0 and 1 when normalize_vector_distance is set to True"
526+
)
527+
528+
# User sets normalized value 0-1 denormalize for use in DB
529+
distance_threshold = denorm_cosine_distance(distance_threshold)
519530
self._distance_threshold = distance_threshold
520531

521532
# Reset the query string

redisvl/utils/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,12 @@ def norm_cosine_distance(value: float) -> float:
197197
"""
198198
Normalize the cosine distance to a similarity score between 0 and 1.
199199
"""
200-
return (2 - value) / 2
200+
return max((2 - value) / 2, 0)
201+
202+
203+
def denorm_cosine_distance(value: float) -> float:
204+
"""Denormalize the distance threshold from [0, 1] to [0, 1] for our db."""
205+
return max(2 - 2 * value, 0)
201206

202207

203208
def norm_l2_distance(value: float) -> float:

tests/integration/test_query.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from redis.commands.search.result import Result
55

66
from redisvl.index import SearchIndex
7-
from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
7+
from redisvl.query import CountQuery, FilterQuery, VectorQuery, VectorRangeQuery
88
from redisvl.query.filter import (
99
FilterExpression,
1010
Geo,
@@ -105,7 +105,7 @@ def sorted_filter_query():
105105

106106
@pytest.fixture
107107
def normalized_range_query():
108-
return RangeQuery(
108+
return VectorRangeQuery(
109109
vector=[0.1, 0.1, 0.5],
110110
vector_field_name="user_embedding",
111111
normalize_vector_distance=True,
@@ -117,7 +117,7 @@ def normalized_range_query():
117117

118118
@pytest.fixture
119119
def range_query():
120-
return RangeQuery(
120+
return VectorRangeQuery(
121121
vector=[0.1, 0.1, 0.5],
122122
vector_field_name="user_embedding",
123123
return_fields=["user", "credit_score", "age", "job", "location"],
@@ -127,7 +127,7 @@ def range_query():
127127

128128
@pytest.fixture
129129
def sorted_range_query():
130-
return RangeQuery(
130+
return VectorRangeQuery(
131131
vector=[0.1, 0.1, 0.5],
132132
vector_field_name="user_embedding",
133133
return_fields=["user", "credit_score", "age", "job", "location"],
@@ -272,7 +272,7 @@ def test_search_and_query(index):
272272

273273

274274
def test_range_query(index):
275-
r = RangeQuery(
275+
r = VectorRangeQuery(
276276
vector=[0.1, 0.1, 0.5],
277277
vector_field_name="user_embedding",
278278
return_fields=["user", "credit_score", "age", "job"],
@@ -343,7 +343,7 @@ def search(
343343
assert doc.location == location
344344

345345
# if range query, test results by distance threshold
346-
if isinstance(query, RangeQuery):
346+
if isinstance(query, VectorRangeQuery):
347347
for doc in results.docs:
348348
print(doc.vector_distance)
349349
assert float(doc.vector_distance) <= distance_threshold
@@ -354,7 +354,7 @@ def search(
354354

355355
# check results are in sorted order
356356
if sort:
357-
if isinstance(query, RangeQuery):
357+
if isinstance(query, VectorRangeQuery):
358358
assert [int(doc.age) for doc in results.docs] == [12, 14, 18, 100]
359359
else:
360360
assert [int(doc.age) for doc in results.docs] == [
@@ -370,7 +370,7 @@ def search(
370370

371371
@pytest.fixture(
372372
params=["vector_query", "filter_query", "range_query"],
373-
ids=["VectorQuery", "FilterQuery", "RangeQuery"],
373+
ids=["VectorQuery", "FilterQuery", "VectorRangeQuery"],
374374
)
375375
def query(request):
376376
return request.getfixturevalue(request.param)
@@ -778,3 +778,15 @@ def test_range_query_normalize_cosine_distance(index, normalized_range_query):
778778

779779
for r in res:
780780
assert 0 <= float(r["vector_distance"]) <= 1
781+
782+
783+
def test_range_query_normalize_bad_input(index):
784+
with pytest.raises(ValueError):
785+
VectorRangeQuery(
786+
vector=[0.1, 0.1, 0.5],
787+
vector_field_name="user_embedding",
788+
normalize_vector_distance=True,
789+
return_score=True,
790+
return_fields=["user", "credit_score", "age", "job", "location"],
791+
distance_threshold=1.2,
792+
)

tests/unit/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from redisvl.utils.utils import (
1818
assert_no_warnings,
19+
denorm_cosine_distance,
1920
deprecated_argument,
2021
deprecated_function,
2122
norm_cosine_distance,
@@ -28,6 +29,17 @@ def test_norm_cosine_distance():
2829
assert norm_cosine_distance(input) == expected
2930

3031

32+
def test_denorm_cosine_distance():
33+
input = 0
34+
expected = 2
35+
assert denorm_cosine_distance(input) == expected
36+
37+
38+
def test_norm_denorm_cosine():
39+
input = 0.6
40+
assert input == round(denorm_cosine_distance(norm_cosine_distance(input)), 6)
41+
42+
3143
def test_even_number_of_elements():
3244
"""Test with an even number of elements"""
3345
values = ["key1", "value1", "key2", "value2"]

0 commit comments

Comments
 (0)