Skip to content

Commit 4a67775

Browse files
committed
update for vector_norm map
1 parent 2172b9f commit 4a67775

File tree

9 files changed

+77
-31
lines changed

9 files changed

+77
-31
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
SemanticCacheIndexSchema,
2323
)
2424
from redisvl.index import AsyncSearchIndex, SearchIndex
25-
from redisvl.query import RangeQuery
25+
from redisvl.query import VectorRangeQuery
2626
from redisvl.query.filter import FilterExpression
2727
from redisvl.query.query import BaseQuery
2828
from redisvl.redis.connection import RedisConnectionFactory
@@ -390,7 +390,7 @@ def check(
390390
vector = vector or self._vectorize_prompt(prompt)
391391
self._check_vector_dims(vector)
392392

393-
query = RangeQuery(
393+
query = VectorRangeQuery(
394394
vector=vector,
395395
vector_field_name=CACHE_VECTOR_FIELD_NAME,
396396
return_fields=self.return_fields,
@@ -473,14 +473,15 @@ async def acheck(
473473
vector = vector or await self._avectorize_prompt(prompt)
474474
self._check_vector_dims(vector)
475475

476-
query = RangeQuery(
476+
query = VectorRangeQuery(
477477
vector=vector,
478478
vector_field_name=CACHE_VECTOR_FIELD_NAME,
479479
return_fields=self.return_fields,
480480
distance_threshold=distance_threshold,
481481
num_results=num_results,
482482
return_score=True,
483483
filter_expression=filter_expression,
484+
normalize_vector_distance=True,
484485
)
485486

486487
# Search the cache!

redisvl/extensions/router/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Route(BaseModel):
1818
"""List of reference phrases for the route."""
1919
metadata: Dict[str, Any] = Field(default={})
2020
"""Metadata associated with the route."""
21-
distance_threshold: Annotated[float, Field(strict=True, gt=0, le=1)] = 0.5
21+
distance_threshold: Annotated[float, Field(strict=True, gt=0, le=2)] = 0.5
2222
"""Distance threshold for matching the route."""
2323

2424
@field_validator("name")

redisvl/extensions/router/semantic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
SemanticRouterIndexSchema,
1818
)
1919
from redisvl.index import SearchIndex
20-
from redisvl.query import RangeQuery
20+
from redisvl.query import VectorRangeQuery
2121
from redisvl.redis.utils import convert_bytes, hashify, make_dict
2222
from redisvl.utils.log import get_logger
2323
from redisvl.utils.utils import deprecated_argument, model_to_dict
@@ -237,7 +237,7 @@ def _distance_threshold_filter(self) -> str:
237237

238238
def _build_aggregate_request(
239239
self,
240-
vector_range_query: RangeQuery,
240+
vector_range_query: VectorRangeQuery,
241241
aggregation_method: DistanceAggregationMethod,
242242
max_k: int,
243243
) -> AggregateRequest:
@@ -279,7 +279,7 @@ def _get_route_matches(
279279
# therefore you might take the max_threshold and further refine from there.
280280
distance_threshold = max(route.distance_threshold for route in self.routes)
281281

282-
vector_range_query = RangeQuery(
282+
vector_range_query = VectorRangeQuery(
283283
vector=vector,
284284
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
285285
distance_threshold=float(distance_threshold),

redisvl/index/index.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
)
5050
from redisvl.redis.utils import convert_bytes
5151
from redisvl.schema import IndexSchema, StorageType
52-
from redisvl.schema.fields import VectorDistanceMetric
52+
from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric
5353
from redisvl.utils.log import get_logger
5454

5555
logger = get_logger(__name__)
@@ -92,12 +92,20 @@ def process_results(
9292
and not query._return_fields # type: ignore
9393
)
9494

95-
normalize_cosine_distance = (
96-
(isinstance(query, VectorQuery) or isinstance(query, VectorRangeQuery))
97-
and query._normalize_cosine_distance
98-
and schema.fields[query._vector_field_name].attrs.distance_metric # type: ignore
99-
== VectorDistanceMetric.COSINE
100-
)
95+
if (
96+
isinstance(query, VectorQuery) or isinstance(query, VectorRangeQuery)
97+
) and query._normalize_vector_distance:
98+
dist_metric = VectorDistanceMetric(
99+
schema.fields[query._vector_field_name].attrs.distance_metric.upper() # type: ignore
100+
)
101+
if dist_metric == VectorDistanceMetric.IP:
102+
warnings.warn(
103+
"Attempting to normalize inner product distance metric. Use cosine distance instead which is normalized inner product by definition."
104+
)
105+
106+
norm_fn = VECTOR_NORM_MAP[dist_metric.value]
107+
else:
108+
norm_fn = None
101109

102110
# Process records
103111
def _process(doc: "Document") -> Dict[str, Any]:
@@ -112,10 +120,10 @@ def _process(doc: "Document") -> Dict[str, Any]:
112120
return {"id": doc_dict.get("id"), **json_data}
113121
raise ValueError(f"Unable to parse json data from Redis {json_data}")
114122

115-
if normalize_cosine_distance:
123+
if norm_fn:
116124
# convert float back to string to be consistent
117125
doc_dict[query.DISTANCE_ID] = str( # type: ignore
118-
norm_cosine_distance(float(doc_dict[query.DISTANCE_ID])) # type: ignore
126+
norm_fn(float(doc_dict[query.DISTANCE_ID])) # type: ignore
119127
)
120128

121129
# Remove 'payload' if present

redisvl/query/query.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def __init__(
188188
dialect: int = 2,
189189
sort_by: Optional[str] = None,
190190
in_order: bool = False,
191-
normalize_cosine_distance: bool = False,
191+
normalize_vector_distance: bool = False,
192192
):
193193
"""A query for running a vector search along with an optional filter
194194
expression.
@@ -214,9 +214,12 @@ def __init__(
214214
in_order (bool): Requires the terms in the field to have
215215
the same order as the terms in the query filter, regardless of
216216
the offsets between them. Defaults to False.
217-
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
218-
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
219-
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
217+
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
218+
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
219+
COSINE distance returns a value between 0 and 2. IP returns a value determined by
220+
the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
221+
to a similarity score between 0 and 1. Note: setting this flag to true for IP will
222+
throw a warning since by definition COSINE similarity is normalized IP.
220223
221224
Raises:
222225
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -228,7 +231,7 @@ def __init__(
228231
self._vector_field_name = vector_field_name
229232
self._dtype = dtype
230233
self._num_results = num_results
231-
self._normalize_cosine_distance = normalize_cosine_distance
234+
self._normalize_vector_distance = normalize_vector_distance
232235
self.set_filter(filter_expression)
233236
query_string = self._build_query_string()
234237

@@ -289,7 +292,7 @@ def __init__(
289292
dialect: int = 2,
290293
sort_by: Optional[str] = None,
291294
in_order: bool = False,
292-
normalize_cosine_distance: bool = False,
295+
normalize_vector_distance: bool = False,
293296
):
294297
"""A query for running a filtered vector search based on semantic
295298
distance threshold.
@@ -318,9 +321,12 @@ def __init__(
318321
in_order (bool): Requires the terms in the field to have
319322
the same order as the terms in the query filter, regardless of
320323
the offsets between them. Defaults to False.
321-
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
322-
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
323-
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
324+
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
325+
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
326+
COSINE distance returns a value between 0 and 2. IP returns a value determined by
327+
the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
328+
to a similarity score between 0 and 1. Note: setting this flag to true for IP will
329+
throw a warning since by definition COSINE similarity is normalized IP.
324330
325331
Raises:
326332
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -333,7 +339,7 @@ def __init__(
333339
self._vector_field_name = vector_field_name
334340
self._dtype = dtype
335341
self._num_results = num_results
336-
self._normalize_cosine_distance = normalize_cosine_distance
342+
self._normalize_vector_distance = normalize_vector_distance
337343
self.set_distance_threshold(distance_threshold)
338344
self.set_filter(filter_expression)
339345
query_string = self._build_query_string()

redisvl/schema/fields.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@
1616
from redis.commands.search.field import TextField as RedisTextField
1717
from redis.commands.search.field import VectorField as RedisVectorField
1818

19+
from redisvl.utils.utils import norm_cosine_distance, norm_l2_distance
20+
21+
VECTOR_NORM_MAP = {
22+
"COSINE": norm_cosine_distance,
23+
"L2": norm_l2_distance,
24+
"IP": None, # normalized inner product is cosine similarity by definition
25+
}
26+
1927

2028
class FieldTypes(str, Enum):
2129
TAG = "tag"

redisvl/utils/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,10 @@ def norm_cosine_distance(value: float) -> float:
198198
Normalize the cosine distance to a similarity score between 0 and 1.
199199
"""
200200
return (2 - value) / 2
201+
202+
203+
def norm_l2_distance(value: float) -> float:
204+
"""
205+
Normalize the L2 distance.
206+
"""
207+
return 1 / (1 + value)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def sample_data(sample_datetimes):
141141
"last_updated": sample_datetimes["high"].timestamp(),
142142
"credit_score": "medium",
143143
"location": "-110.0839,37.3861",
144-
"user_embedding": [0.9, 0.9, 0.1],
144+
"user_embedding": [-0.1, -0.1, -0.5],
145145
},
146146
]
147147

tests/integration/test_query.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def vector_query():
2424
return VectorQuery(
2525
vector=[0.1, 0.1, 0.5],
2626
vector_field_name="user_embedding",
27+
return_score=True,
2728
return_fields=[
2829
"user",
2930
"credit_score",
@@ -57,7 +58,7 @@ def normalized_vector_query():
5758
return VectorQuery(
5859
vector=[0.1, 0.1, 0.5],
5960
vector_field_name="user_embedding",
60-
normalize_cosine_distance=True,
61+
normalize_vector_distance=True,
6162
return_score=True,
6263
return_fields=[
6364
"user",
@@ -106,7 +107,7 @@ def normalized_range_query():
106107
return RangeQuery(
107108
vector=[0.1, 0.1, 0.5],
108109
vector_field_name="user_embedding",
109-
normalize_cosine_distance=True,
110+
normalize_vector_distance=True,
110111
return_score=True,
111112
return_fields=["user", "credit_score", "age", "job", "location"],
112113
distance_threshold=0.2,
@@ -621,13 +622,28 @@ def test_query_normalize_cosine_distance(index, normalized_vector_query):
621622
assert 0 <= float(r["vector_distance"]) <= 1
622623

623624

624-
def test_query_normalize_cosine_distance_lp_distance(L2_index, normalized_vector_query):
625+
def test_query_cosine_distance_un_normalized(index, vector_query):
625626

626-
res = L2_index.query(normalized_vector_query)
627+
res = index.query(vector_query)
628+
629+
assert any(float(r["vector_distance"]) > 1 for r in res)
630+
631+
632+
def test_query_l2_distance_un_normalized(L2_index, vector_query):
633+
634+
res = L2_index.query(vector_query)
627635

628636
assert any(float(r["vector_distance"]) > 1 for r in res)
629637

630638

639+
def test_query_l2_distance_normalized(L2_index, normalized_vector_query):
640+
641+
res = L2_index.query(normalized_vector_query)
642+
643+
for r in res:
644+
assert 0 <= float(r["vector_distance"]) <= 1
645+
646+
631647
def test_range_query_normalize_cosine_distance(index, normalized_range_query):
632648

633649
res = index.query(normalized_range_query)

0 commit comments

Comments
 (0)