Skip to content

Commit e96d739

Browse files
authored
Consider vectorizer the owner of dtype (#267)
The `dtype` argument to `SemanticCache`, `SemanticRouter`, and `SemanticSessionManager` is intended to specify the vectorizer's data type. However, the constructors for these classes allow passing in both a vectorizer instance (with dtype already set) and `dtype`, in which case, the vectorizer may have a different `dtype` than the `dtype` used in the index schema. This PR works on eliminating this possibility by adding validation and treating the vectorizer as the true "owner" of `dtype`. We also begin the deprecation process for standalone `dtype` arguments, guiding users to pass in vectorizer instances instead if they require customizing the vectorizer's dtype. Finally, we make vectorizer validation consistent across these classes.
1 parent ca6396c commit e96d739

File tree

9 files changed

+335
-53
lines changed

9 files changed

+335
-53
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
from redisvl.index import AsyncSearchIndex, SearchIndex
2323
from redisvl.query import RangeQuery
2424
from redisvl.query.filter import FilterExpression
25-
from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims
25+
from redisvl.utils.utils import (
26+
current_timestamp,
27+
deprecated_argument,
28+
serialize,
29+
validate_vector_dims,
30+
)
2631
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
2732

2833

@@ -32,6 +37,7 @@ class SemanticCache(BaseLLMCache):
3237
_index: SearchIndex
3338
_aindex: Optional[AsyncSearchIndex] = None
3439

40+
@deprecated_argument("dtype", "vectorizer")
3541
def __init__(
3642
self,
3743
name: str = "llmcache",
@@ -86,12 +92,26 @@ def __init__(
8692
else:
8793
prefix = name
8894

89-
# Set vectorizer default
90-
if vectorizer is None:
95+
dtype = kwargs.get("dtype")
96+
97+
# Validate a provided vectorizer or set the default
98+
if vectorizer:
99+
if not isinstance(vectorizer, BaseVectorizer):
100+
raise TypeError("Must provide a valid redisvl.vectorizer class.")
101+
if dtype and vectorizer.dtype != dtype:
102+
raise ValueError(
103+
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
104+
)
105+
else:
106+
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
107+
91108
vectorizer = HFTextVectorizer(
92-
model="sentence-transformers/all-mpnet-base-v2"
109+
model="sentence-transformers/all-mpnet-base-v2",
110+
**vectorizer_kwargs,
93111
)
94112

113+
self._vectorizer = vectorizer
114+
95115
# Process fields and other settings
96116
self.set_threshold(distance_threshold)
97117
self.return_fields = [
@@ -104,9 +124,8 @@ def __init__(
104124
]
105125

106126
# Create semantic cache schema and index
107-
dtype = kwargs.get("dtype", "float32")
108127
schema = SemanticCacheIndexSchema.from_params(
109-
name, prefix, vectorizer.dims, dtype
128+
name, prefix, vectorizer.dims, vectorizer.dtype
110129
)
111130
schema = self._modify_schema(schema, filterable_fields)
112131
self._index = SearchIndex(schema=schema)
@@ -128,20 +147,9 @@ def __init__(
128147
"If you wish to overwrite the index schema, set overwrite=True during initialization."
129148
)
130149

131-
# Create the search index
150+
# Create the search index in Redis
132151
self._index.create(overwrite=overwrite, drop=False)
133152

134-
# Initialize and validate vectorizer
135-
if not isinstance(vectorizer, BaseVectorizer):
136-
raise TypeError("Must provide a valid redisvl.vectorizer class.")
137-
138-
validate_vector_dims(
139-
vectorizer.dims,
140-
self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore
141-
)
142-
self._vectorizer = vectorizer
143-
self._dtype = self.index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.datatype # type: ignore[union-attr]
144-
145153
def _modify_schema(
146154
self,
147155
schema: SemanticCacheIndexSchema,
@@ -290,7 +298,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
290298
if not isinstance(prompt, str):
291299
raise TypeError("Prompt must be a string.")
292300

293-
return self._vectorizer.embed(prompt, dtype=self._dtype)
301+
return self._vectorizer.embed(prompt)
294302

295303
async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
296304
"""Converts a text prompt to its vector representation using the
@@ -372,7 +380,7 @@ def check(
372380
num_results=num_results,
373381
return_score=True,
374382
filter_expression=filter_expression,
375-
dtype=self._dtype,
383+
dtype=self._vectorizer.dtype,
376384
)
377385

378386
# Search the cache!
@@ -543,7 +551,7 @@ def store(
543551
# Load cache entry with TTL
544552
ttl = ttl or self._ttl
545553
keys = self._index.load(
546-
data=[cache_entry.to_dict(self._dtype)],
554+
data=[cache_entry.to_dict(self._vectorizer.dtype)],
547555
ttl=ttl,
548556
id_field=ENTRY_ID_FIELD_NAME,
549557
)
@@ -607,7 +615,7 @@ async def astore(
607615
# Load cache entry with TTL
608616
ttl = ttl or self._ttl
609617
keys = await aindex.load(
610-
data=[cache_entry.to_dict(self._dtype)],
618+
data=[cache_entry.to_dict(self._vectorizer.dtype)],
611619
ttl=ttl,
612620
id_field=ENTRY_ID_FIELD_NAME,
613621
)

redisvl/extensions/router/semantic.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from redisvl.query import RangeQuery
2121
from redisvl.redis.utils import convert_bytes, hashify, make_dict
2222
from redisvl.utils.log import get_logger
23-
from redisvl.utils.utils import model_to_dict
23+
from redisvl.utils.utils import deprecated_argument, model_to_dict
2424
from redisvl.utils.vectorize import (
2525
BaseVectorizer,
2626
HFTextVectorizer,
@@ -47,6 +47,7 @@ class SemanticRouter(BaseModel):
4747
class Config:
4848
arbitrary_types_allowed = True
4949

50+
@deprecated_argument("dtype", "vectorizer")
5051
def __init__(
5152
self,
5253
name: str,
@@ -72,9 +73,19 @@ def __init__(
7273
connection_kwargs (Dict[str, Any]): The connection arguments
7374
for the redis client. Defaults to empty {}.
7475
"""
75-
# Set vectorizer default
76-
if vectorizer is None:
77-
vectorizer = HFTextVectorizer()
76+
dtype = kwargs.get("dtype")
77+
78+
# Validate a provided vectorizer or set the default
79+
if vectorizer:
80+
if not isinstance(vectorizer, BaseVectorizer):
81+
raise TypeError("Must provide a valid redisvl.vectorizer class.")
82+
if dtype and vectorizer.dtype != dtype:
83+
raise ValueError(
84+
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
85+
)
86+
else:
87+
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
88+
vectorizer = HFTextVectorizer(**vectorizer_kwargs)
7889

7990
if routing_config is None:
8091
routing_config = RoutingConfig()
@@ -85,11 +96,9 @@ def __init__(
8596
vectorizer=vectorizer,
8697
routing_config=routing_config,
8798
)
88-
dtype = kwargs.get("dtype", "float32")
89-
self._initialize_index(
90-
redis_client, redis_url, overwrite, dtype, **connection_kwargs
91-
)
99+
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)
92100

101+
@deprecated_argument("dtype")
93102
def _initialize_index(
94103
self,
95104
redis_client: Optional[Redis] = None,
@@ -100,7 +109,7 @@ def _initialize_index(
100109
):
101110
"""Initialize the search index and handle Redis connection."""
102111
schema = SemanticRouterIndexSchema.from_params(
103-
self.name, self.vectorizer.dims, dtype
112+
self.name, self.vectorizer.dims, self.vectorizer.dtype
104113
)
105114
self._index = SearchIndex(schema=schema)
106115

@@ -169,9 +178,7 @@ def _add_routes(self, routes: List[Route]):
169178
for route in routes:
170179
# embed route references as a single batch
171180
reference_vectors = self.vectorizer.embed_many(
172-
[reference for reference in route.references],
173-
as_buffer=True,
174-
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
181+
[reference for reference in route.references], as_buffer=True
175182
)
176183
# set route references
177184
for i, reference in enumerate(route.references):
@@ -248,7 +255,6 @@ def _classify_route(
248255
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
249256
distance_threshold=distance_threshold,
250257
return_fields=["route_name"],
251-
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
252258
)
253259

254260
aggregate_request = self._build_aggregate_request(
@@ -301,7 +307,6 @@ def _classify_multi_route(
301307
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
302308
distance_threshold=distance_threshold,
303309
return_fields=["route_name"],
304-
dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
305310
)
306311
aggregate_request = self._build_aggregate_request(
307312
vector_range_query, aggregation_method, max_k

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
from redisvl.index import SearchIndex
2020
from redisvl.query import FilterQuery, RangeQuery
2121
from redisvl.query.filter import Tag
22-
from redisvl.utils.utils import validate_vector_dims
22+
from redisvl.utils.utils import deprecated_argument, validate_vector_dims
2323
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
2424

2525

2626
class SemanticSessionManager(BaseSessionManager):
2727

28+
@deprecated_argument("dtype", "vectorizer")
2829
def __init__(
2930
self,
3031
name: str,
@@ -70,16 +71,30 @@ def __init__(
7071
super().__init__(name, session_tag)
7172

7273
prefix = prefix or name
74+
dtype = kwargs.get("dtype")
7375

74-
self._vectorizer = vectorizer or HFTextVectorizer(
75-
model="sentence-transformers/msmarco-distilbert-cos-v5"
76-
)
76+
# Validate a provided vectorizer or set the default
77+
if vectorizer:
78+
if not isinstance(vectorizer, BaseVectorizer):
79+
raise TypeError("Must provide a valid redisvl.vectorizer class.")
80+
if dtype and vectorizer.dtype != dtype:
81+
raise ValueError(
82+
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
83+
)
84+
else:
85+
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
86+
87+
vectorizer = HFTextVectorizer(
88+
model="sentence-transformers/msmarco-distilbert-cos-v5",
89+
**vectorizer_kwargs,
90+
)
91+
92+
self._vectorizer = vectorizer
7793

7894
self.set_distance_threshold(distance_threshold)
7995

80-
dtype = kwargs.get("dtype", "float32")
8196
schema = SemanticSessionIndexSchema.from_params(
82-
name, prefix, self._vectorizer.dims, dtype
97+
name, prefix, self._vectorizer.dims, vectorizer.dtype
8398
)
8499

85100
self._index = SearchIndex(schema=schema)
@@ -215,7 +230,7 @@ def get_relevant(
215230
num_results=top_k,
216231
return_score=True,
217232
filter_expression=session_filter,
218-
dtype=self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
233+
dtype=self._vectorizer.dtype,
219234
)
220235
messages = self._index.query(query)
221236

@@ -341,7 +356,7 @@ def add_messages(
341356
if TOOL_FIELD_NAME in message:
342357
chat_message.tool_call_id = message[TOOL_FIELD_NAME]
343358

344-
chat_messages.append(chat_message.to_dict(dtype=self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.datatype)) # type: ignore[union-attr]
359+
chat_messages.append(chat_message.to_dict(dtype=self._vectorizer.dtype))
345360

346361
self._index.load(data=chat_messages, id_field=ID_FIELD_NAME)
347362

redisvl/utils/utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import json
22
from enum import Enum
3+
from functools import wraps
34
from time import time
4-
from typing import Any, Dict
5+
from typing import Any, Callable, Dict, Optional
56
from uuid import uuid4
7+
from warnings import warn
68

79
from pydantic.v1 import BaseModel
810

@@ -57,3 +59,32 @@ def serialize(data: Dict[str, Any]) -> str:
5759
def deserialize(data: str) -> Dict[str, Any]:
5860
"""Deserialize the input from a string."""
5961
return json.loads(data)
62+
63+
64+
def deprecated_argument(argument: str, replacement: Optional[str] = None) -> Callable:
    """
    Decorator to warn if a deprecated argument is passed.

    When the wrapped function is called, the decorator emits a
    ``DeprecationWarning`` only if the caller actually supplied the
    deprecated argument, either by keyword or positionally.

    Args:
        argument: Name of the deprecated argument.
        replacement: Optional name of the argument callers should use
            instead; appended to the warning message when given.

    Returns:
        A decorator that wraps a function and returns a function with the
        same call behavior, plus the deprecation warning.
    """

    message = f"Argument {argument} is deprecated and will be removed in the next major release."
    if replacement:
        message += f" Use {replacement} instead."

    def wrapper(func):
        # Only the leading ``co_argcount`` entries of ``co_varnames`` are
        # parameters that can be bound positionally; the rest are
        # keyword-only parameters and *local variables*. Checking the whole
        # tuple would warn on every call of a function that merely has a
        # local with the deprecated name (e.g. ``dtype = kwargs.get("dtype")``).
        code = getattr(func, "__code__", None)
        positional_names = code.co_varnames[: code.co_argcount] if code else ()
        position = (
            positional_names.index(argument) if argument in positional_names else None
        )

        @wraps(func)
        def inner(*args, **kwargs):
            passed_by_keyword = argument in kwargs
            passed_by_position = position is not None and len(args) > position

            if passed_by_keyword or passed_by_position:
                # stacklevel=2 attributes the warning to the caller of the
                # decorated function rather than to this wrapper.
                warn(message, DeprecationWarning, stacklevel=2)

            return func(*args, **kwargs)

        return inner

    return wrapper

tests/integration/test_llmcache.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
from collections import namedtuple
44
from time import sleep, time
5+
import warnings
56

67
import pytest
78
from pydantic.v1 import ValidationError
@@ -71,6 +72,13 @@ def cache_with_redis_client(vectorizer, client):
7172
cache_instance._index.delete(True) # Clean up index
7273

7374

75+
@pytest.fixture(autouse=True)
def disable_deprecation_warnings():
    """Run the test body inside a warnings context that ignores warnings.

    The filter installed here has no category restriction, so it ignores
    every warning category, not only ``DeprecationWarning``. The previous
    filter state is restored when the context exits after the test.
    """
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore")
        yield
80+
81+
7482
def test_bad_ttl(cache):
7583
with pytest.raises(ValueError):
7684
cache.set_ttl(2.5)
@@ -884,3 +892,35 @@ def test_bad_dtype_connecting_to_existing_cache(redis_url):
884892
bad_type = SemanticCache(
885893
name="float64_cache", dtype="float16", redis_url=redis_url
886894
)
895+
896+
897+
def test_vectorizer_dtype_mismatch():
    """A standalone dtype that disagrees with the vectorizer's dtype raises ValueError."""
    with pytest.raises(ValueError):
        # Built inside the context so any ValueError raised while creating
        # the vectorizer is still covered by the expectation.
        float16_vectorizer = HFTextVectorizer(dtype="float16")
        SemanticCache(
            name="test_dtype_mismatch",
            dtype="float32",
            vectorizer=float16_vectorizer,
            overwrite=True,
        )
905+
906+
907+
def test_invalid_vectorizer():
    """Passing a plain string as the vectorizer raises TypeError."""
    not_a_vectorizer = "invalid_vectorizer"

    with pytest.raises(TypeError):
        SemanticCache(
            name="test_invalid_vectorizer",
            vectorizer=not_a_vectorizer,  # type: ignore
            overwrite=True,
        )
914+
915+
916+
def test_passes_through_dtype_to_default_vectorizer():
    """With no vectorizer supplied, the dtype argument reaches the default vectorizer."""
    # The default is float32, so a non-default value shows the pass-through.
    requested_dtype = "float64"

    cache = SemanticCache(
        name="test_pass_through_dtype",
        dtype=requested_dtype,
        overwrite=True,
    )

    assert cache._vectorizer.dtype == requested_dtype
922+
923+
924+
def test_deprecated_dtype_argument():
    """Constructing a cache with the standalone dtype argument emits a DeprecationWarning."""
    cache_kwargs = {
        "name": "test_deprecated_dtype",
        "dtype": "float32",
        "overwrite": True,
    }

    with pytest.warns(DeprecationWarning):
        SemanticCache(**cache_kwargs)

0 commit comments

Comments
 (0)