Skip to content

Commit 9c88680

Browse files
committed
add tests edit bounds
1 parent 4a67775 commit 9c88680

File tree

7 files changed

+58
-27
lines changed

7 files changed

+58
-27
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,9 @@ def set_threshold(self, distance_threshold: float) -> None:
238238
Raises:
239239
ValueError: If the threshold is not between 0 and 2.
240240
"""
241-
if not 0 <= float(distance_threshold) <= 1:
241+
if not 0 <= float(distance_threshold) <= 2:
242242
raise ValueError(
243-
f"Distance must be between 0 and 1, got {distance_threshold}"
243+
f"Distance must be between 0 and 2, got {distance_threshold}"
244244
)
245245
self._distance_threshold = float(distance_threshold)
246246

redisvl/index/index.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@
1616
Union,
1717
)
1818

19-
from redisvl.utils.utils import (
20-
deprecated_argument,
21-
deprecated_function,
22-
norm_cosine_distance,
23-
sync_wrapper,
24-
)
19+
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
2520

2621
if TYPE_CHECKING:
2722
from redis.commands.search.aggregation import AggregateResult
@@ -35,13 +30,7 @@
3530

3631
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
3732
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
38-
from redisvl.query import (
39-
BaseQuery,
40-
CountQuery,
41-
FilterQuery,
42-
VectorQuery,
43-
VectorRangeQuery,
44-
)
33+
from redisvl.query import BaseQuery, BaseVectorQuery, CountQuery, FilterQuery
4534
from redisvl.query.filter import FilterExpression
4635
from redisvl.redis.connection import (
4736
RedisConnectionFactory,
@@ -92,9 +81,7 @@ def process_results(
9281
and not query._return_fields # type: ignore
9382
)
9483

95-
if (
96-
isinstance(query, VectorQuery) or isinstance(query, VectorRangeQuery)
97-
) and query._normalize_vector_distance:
84+
if (isinstance(query, BaseVectorQuery)) and query._normalize_vector_distance:
9885
dist_metric = VectorDistanceMetric(
9986
schema.fields[query._vector_field_name].attrs.distance_metric.upper() # type: ignore
10087
)

redisvl/query/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from redisvl.query.query import (
22
BaseQuery,
3+
BaseVectorQuery,
34
CountQuery,
45
FilterQuery,
56
RangeQuery,
@@ -9,6 +10,7 @@
910

1011
__all__ = [
1112
"BaseQuery",
13+
"BaseVectorQuery",
1214
"VectorQuery",
1315
"FilterQuery",
1416
"RangeQuery",

redisvl/query/query.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from redisvl.query.filter import FilterExpression
66
from redisvl.redis.utils import array_to_buffer
7+
from redisvl.utils.utils import denorm_cosine_distance
78

89

910
class BaseQuery(RedisQuery):
@@ -174,6 +175,8 @@ class BaseVectorQuery:
174175
DISTANCE_ID: str = "vector_distance"
175176
VECTOR_PARAM: str = "vector"
176177

178+
_normalize_vector_distance: bool = False
179+
177180

178181
class VectorQuery(BaseVectorQuery, BaseQuery):
179182
def __init__(
@@ -383,6 +386,16 @@ def set_distance_threshold(self, distance_threshold: float):
383386
"""
384387
if not isinstance(distance_threshold, (float, int)):
385388
raise TypeError("distance_threshold must be of type int or float")
389+
390+
if self._normalize_vector_distance:
391+
if distance_threshold < 0 or distance_threshold > 1:
392+
raise ValueError(
393+
"distance_threshold must be between 0 and 1 when normalize_vector_distance is set to True"
394+
)
395+
396+
# The user supplies a normalized value in [0, 1]; denormalize it for use in the DB.
397+
distance_threshold = denorm_cosine_distance(distance_threshold)
398+
386399
self._distance_threshold = distance_threshold
387400

388401
@property

redisvl/utils/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,12 @@ def norm_cosine_distance(value: float) -> float:
197197
"""
198198
Normalize the cosine distance to a similarity score between 0 and 1.
199199
"""
200-
return (2 - value) / 2
200+
return max((2 - value) / 2, 0)
201+
202+
203+
def denorm_cosine_distance(value: float) -> float:
204+
"""Denormalize the distance threshold from [0, 1] to [0, 2] for our db."""
205+
return max(2 - 2 * value, 0)
201206

202207

203208
def norm_l2_distance(value: float) -> float:

tests/integration/test_query.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from redis.commands.search.result import Result
55

66
from redisvl.index import SearchIndex
7-
from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery
7+
from redisvl.query import CountQuery, FilterQuery, VectorQuery, VectorRangeQuery
88
from redisvl.query.filter import (
99
FilterExpression,
1010
Geo,
@@ -104,7 +104,7 @@ def sorted_filter_query():
104104

105105
@pytest.fixture
106106
def normalized_range_query():
107-
return RangeQuery(
107+
return VectorRangeQuery(
108108
vector=[0.1, 0.1, 0.5],
109109
vector_field_name="user_embedding",
110110
normalize_vector_distance=True,
@@ -116,7 +116,7 @@ def normalized_range_query():
116116

117117
@pytest.fixture
118118
def range_query():
119-
return RangeQuery(
119+
return VectorRangeQuery(
120120
vector=[0.1, 0.1, 0.5],
121121
vector_field_name="user_embedding",
122122
return_fields=["user", "credit_score", "age", "job", "location"],
@@ -126,7 +126,7 @@ def range_query():
126126

127127
@pytest.fixture
128128
def sorted_range_query():
129-
return RangeQuery(
129+
return VectorRangeQuery(
130130
vector=[0.1, 0.1, 0.5],
131131
vector_field_name="user_embedding",
132132
return_fields=["user", "credit_score", "age", "job", "location"],
@@ -271,7 +271,7 @@ def test_search_and_query(index):
271271

272272

273273
def test_range_query(index):
274-
r = RangeQuery(
274+
r = VectorRangeQuery(
275275
vector=[0.1, 0.1, 0.5],
276276
vector_field_name="user_embedding",
277277
return_fields=["user", "credit_score", "age", "job"],
@@ -342,7 +342,7 @@ def search(
342342
assert doc.location == location
343343

344344
# if range query, test results by distance threshold
345-
if isinstance(query, RangeQuery):
345+
if isinstance(query, VectorRangeQuery):
346346
for doc in results.docs:
347347
print(doc.vector_distance)
348348
assert float(doc.vector_distance) <= distance_threshold
@@ -353,7 +353,7 @@ def search(
353353

354354
# check results are in sorted order
355355
if sort:
356-
if isinstance(query, RangeQuery):
356+
if isinstance(query, VectorRangeQuery):
357357
assert [int(doc.age) for doc in results.docs] == [12, 14, 18, 100]
358358
else:
359359
assert [int(doc.age) for doc in results.docs] == [
@@ -369,7 +369,7 @@ def search(
369369

370370
@pytest.fixture(
371371
params=["vector_query", "filter_query", "range_query"],
372-
ids=["VectorQuery", "FilterQuery", "RangeQuery"],
372+
ids=["VectorQuery", "FilterQuery", "VectorRangeQuery"],
373373
)
374374
def query(request):
375375
return request.getfixturevalue(request.param)
@@ -650,3 +650,15 @@ def test_range_query_normalize_cosine_distance(index, normalized_range_query):
650650

651651
for r in res:
652652
assert 0 <= float(r["vector_distance"]) <= 1
653+
654+
655+
def test_range_query_normalize_bad_input(index):
656+
with pytest.raises(ValueError):
657+
VectorRangeQuery(
658+
vector=[0.1, 0.1, 0.5],
659+
vector_field_name="user_embedding",
660+
normalize_vector_distance=True,
661+
return_score=True,
662+
return_fields=["user", "credit_score", "age", "job", "location"],
663+
distance_threshold=1.2,
664+
)

tests/unit/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from redisvl.utils.utils import (
1313
assert_no_warnings,
14+
denorm_cosine_distance,
1415
deprecated_argument,
1516
deprecated_function,
1617
norm_cosine_distance,
@@ -23,6 +24,17 @@ def test_norm_cosine_distance():
2324
assert norm_cosine_distance(input) == expected
2425

2526

27+
def test_denorm_cosine_distance():
28+
input = 0
29+
expected = 2
30+
assert denorm_cosine_distance(input) == expected
31+
32+
33+
def test_norm_denorm_cosine():
34+
input = 0.6
35+
assert input == round(denorm_cosine_distance(norm_cosine_distance(input)), 6)
36+
37+
2638
def test_even_number_of_elements():
2739
"""Test with an even number of elements"""
2840
values = ["key1", "value1", "key2", "value2"]

0 commit comments

Comments
 (0)