Skip to content

Commit 3395c87

Browse files
committed
feat: add text field weights support to TextQuery (#360)
Adds the ability to specify weights for text fields in RedisVL queries, enabling users to prioritize certain fields over others in search results. - Support dictionary of field:weight mappings in TextQuery constructor - Maintain backward compatibility with single string field names - Add set_field_weights() method for dynamic weight updates - Generate proper Redis query syntax with weight modifiers - Comprehensive validation for positive numeric weights Example usage: ```python query = TextQuery(text="search", text_field_name={"title": 5.0}) query = TextQuery( text="search", text_field_name={"title": 3.0, "content": 1.5, "tags": 1.0} ) ``` - Add has_redisearch_module and has_redisearch_module_async helpers to conftest.py - Add skip_if_no_redisearch and skip_if_no_redisearch_async functions - Update test_no_proactive_module_checks.py to use shared helpers - Update test_semantic_router.py to check RediSearch availability in fixtures and tests - Update test_llmcache.py to check RediSearch availability in all cache fixtures - Update test_message_history.py to check RediSearch availability for semantic history - Ensure all tests that require RediSearch are properly skipped on Redis 6.2.6-v9 - BM25STD scorer is not available in Redis versions prior to 7.2.0. Add version check to skip these tests on older Redis versions.
1 parent a56f9a1 commit 3395c87

File tree

8 files changed

+514
-54
lines changed

8 files changed

+514
-54
lines changed

redisvl/query/query.py

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ class TextQuery(BaseQuery):
801801
def __init__(
802802
self,
803803
text: str,
804-
text_field_name: str,
804+
text_field_name: Union[str, Dict[str, float]],
805805
text_scorer: str = "BM25STD",
806806
filter_expression: Optional[Union[str, FilterExpression]] = None,
807807
return_fields: Optional[List[str]] = None,
@@ -817,7 +817,8 @@ def __init__(
817817
818818
Args:
819819
text (str): The text string to perform the text search with.
820-
text_field_name (str): The name of the document field to perform text search on.
820+
text_field_name (Union[str, Dict[str, float]]): The name of the document field to perform
821+
text search on, or a dictionary mapping field names to their weights.
821822
text_scorer (str, optional): The text scoring algorithm to use.
822823
Defaults to BM25STD. Options are {TFIDF, BM25STD, BM25, TFIDF.DOCNORM, DISMAX, DOCSCORE}.
823824
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
@@ -849,7 +850,7 @@ def __init__(
849850
TypeError: If stopwords is not a valid iterable set of strings.
850851
"""
851852
self._text = text
852-
self._text_field_name = text_field_name
853+
self._field_weights = self._parse_field_weights(text_field_name)
853854
self._num_results = num_results
854855

855856
self._set_stopwords(stopwords)
@@ -934,15 +935,97 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
934935
[token for token in tokens if token and token not in self._stopwords]
935936
)
936937

938+
def _parse_field_weights(
939+
self, field_spec: Union[str, Dict[str, float]]
940+
) -> Dict[str, float]:
941+
"""Parse the field specification into a weights dictionary.
942+
943+
Args:
944+
field_spec: Either a single field name or dictionary of field:weight mappings
945+
946+
Returns:
947+
Dictionary mapping field names to their weights
948+
"""
949+
if isinstance(field_spec, str):
950+
return {field_spec: 1.0}
951+
elif isinstance(field_spec, dict):
952+
# Validate all weights are numeric and positive
953+
for field, weight in field_spec.items():
954+
if not isinstance(field, str):
955+
raise TypeError(f"Field name must be a string, got {type(field)}")
956+
if not isinstance(weight, (int, float)):
957+
raise TypeError(
958+
f"Weight for field '{field}' must be numeric, got {type(weight)}"
959+
)
960+
if weight <= 0:
961+
raise ValueError(
962+
f"Weight for field '{field}' must be positive, got {weight}"
963+
)
964+
return field_spec
965+
else:
966+
raise TypeError(
967+
"text_field_name must be a string or dictionary of field:weight mappings"
968+
)
969+
970+
def set_field_weights(self, field_weights: Union[str, Dict[str, float]]):
971+
"""Set or update the field weights for the query.
972+
973+
Args:
974+
field_weights: Either a single field name or dictionary of field:weight mappings
975+
"""
976+
self._field_weights = self._parse_field_weights(field_weights)
977+
# Invalidate the query string
978+
self._built_query_string = None
979+
980+
@property
981+
def field_weights(self) -> Dict[str, float]:
982+
"""Get the field weights for the query.
983+
984+
Returns:
985+
Dictionary mapping field names to their weights
986+
"""
987+
return self._field_weights.copy()
988+
989+
@property
990+
def text_field_name(self) -> Union[str, Dict[str, float]]:
991+
"""Get the text field name(s) - for backward compatibility.
992+
993+
Returns:
994+
Either a single field name string (if only one field with weight 1.0)
995+
or a dictionary of field:weight mappings.
996+
"""
997+
if len(self._field_weights) == 1:
998+
field, weight = next(iter(self._field_weights.items()))
999+
if weight == 1.0:
1000+
return field
1001+
return self._field_weights.copy()
1002+
9371003
def _build_query_string(self) -> str:
9381004
"""Build the full query string for text search with optional filtering."""
9391005
filter_expression = self._filter_expression
9401006
if isinstance(filter_expression, FilterExpression):
9411007
filter_expression = str(filter_expression)
9421008

943-
text = (
944-
f"@{self._text_field_name}:({self._tokenize_and_escape_query(self._text)})"
945-
)
1009+
escaped_query = self._tokenize_and_escape_query(self._text)
1010+
1011+
# Build query parts for each field with its weight
1012+
field_queries = []
1013+
for field, weight in self._field_weights.items():
1014+
if weight == 1.0:
1015+
# Default weight doesn't need explicit weight syntax
1016+
field_queries.append(f"@{field}:({escaped_query})")
1017+
else:
1018+
# Use Redis weight syntax for non-default weights
1019+
field_queries.append(
1020+
f"@{field}:({escaped_query}) => {{ $weight: {weight} }}"
1021+
)
1022+
1023+
# Join multiple field queries with OR operator
1024+
if len(field_queries) == 1:
1025+
text = field_queries[0]
1026+
else:
1027+
text = "(" + " | ".join(field_queries) + ")"
1028+
9461029
if filter_expression and filter_expression != "*":
9471030
text += f" AND {filter_expression}"
9481031
return text

tests/conftest.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,26 @@ async def get_redis_version_async(client):
579579
return info["redis_version"]
580580

581581

582+
def has_redisearch_module(client):
583+
"""Check if RediSearch module is available."""
584+
try:
585+
# Try to list indices - this is a RediSearch command
586+
client.execute_command("FT._LIST")
587+
return True
588+
except Exception:
589+
return False
590+
591+
592+
async def has_redisearch_module_async(client):
593+
"""Check if RediSearch module is available (async)."""
594+
try:
595+
# Try to list indices - this is a RediSearch command
596+
await client.execute_command("FT._LIST")
597+
return True
598+
except Exception:
599+
return False
600+
601+
582602
def skip_if_redis_version_below(client, min_version: str, message: str = None):
583603
"""
584604
Skip test if Redis version is below minimum required.
@@ -609,3 +629,29 @@ async def skip_if_redis_version_below_async(
609629
if not compare_versions(redis_version, min_version):
610630
skip_msg = message or f"Redis version {redis_version} < {min_version} required"
611631
pytest.skip(skip_msg)
632+
633+
634+
def skip_if_no_redisearch(client, message: str = None):
635+
"""
636+
Skip test if RediSearch module is not available.
637+
638+
Args:
639+
client: Redis client instance
640+
message: Custom skip message
641+
"""
642+
if not has_redisearch_module(client):
643+
skip_msg = message or "RediSearch module not available"
644+
pytest.skip(skip_msg)
645+
646+
647+
async def skip_if_no_redisearch_async(client, message: str = None):
648+
"""
649+
Skip test if RediSearch module is not available (async version).
650+
651+
Args:
652+
client: Async Redis client instance
653+
message: Custom skip message
654+
"""
655+
if not await has_redisearch_module_async(client):
656+
skip_msg = message or "RediSearch module not available"
657+
pytest.skip(skip_msg)

tests/integration/test_llmcache.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from redisvl.index.index import AsyncSearchIndex, SearchIndex
1212
from redisvl.query.filter import Num, Tag, Text
1313
from redisvl.utils.vectorize import HFTextVectorizer
14+
from tests.conftest import skip_if_no_redisearch, skip_if_no_redisearch_async
1415

1516

1617
@pytest.fixture(scope="session")
@@ -19,7 +20,8 @@ def vectorizer():
1920

2021

2122
@pytest.fixture
22-
def cache(vectorizer, redis_url, worker_id):
23+
def cache(client, vectorizer, redis_url, worker_id):
24+
skip_if_no_redisearch(client)
2325
cache_instance = SemanticCache(
2426
name=f"llmcache_{worker_id}",
2527
vectorizer=vectorizer,
@@ -31,7 +33,8 @@ def cache(vectorizer, redis_url, worker_id):
3133

3234

3335
@pytest.fixture
34-
def cache_with_filters(vectorizer, redis_url, worker_id):
36+
def cache_with_filters(client, vectorizer, redis_url, worker_id):
37+
skip_if_no_redisearch(client)
3538
cache_instance = SemanticCache(
3639
name=f"llmcache_filters_{worker_id}",
3740
vectorizer=vectorizer,
@@ -44,7 +47,8 @@ def cache_with_filters(vectorizer, redis_url, worker_id):
4447

4548

4649
@pytest.fixture
47-
def cache_no_cleanup(vectorizer, redis_url, worker_id):
50+
def cache_no_cleanup(client, vectorizer, redis_url, worker_id):
51+
skip_if_no_redisearch(client)
4852
cache_instance = SemanticCache(
4953
name=f"llmcache_no_cleanup_{worker_id}",
5054
vectorizer=vectorizer,
@@ -55,7 +59,8 @@ def cache_no_cleanup(vectorizer, redis_url, worker_id):
5559

5660

5761
@pytest.fixture
58-
def cache_with_ttl(vectorizer, redis_url, worker_id):
62+
def cache_with_ttl(client, vectorizer, redis_url, worker_id):
63+
skip_if_no_redisearch(client)
5964
cache_instance = SemanticCache(
6065
name=f"llmcache_ttl_{worker_id}",
6166
vectorizer=vectorizer,
@@ -69,6 +74,7 @@ def cache_with_ttl(vectorizer, redis_url, worker_id):
6974

7075
@pytest.fixture
7176
def cache_with_redis_client(vectorizer, client, worker_id):
77+
skip_if_no_redisearch(client)
7278
cache_instance = SemanticCache(
7379
name=f"llmcache_client_{worker_id}",
7480
vectorizer=vectorizer,
@@ -750,7 +756,8 @@ def test_cache_filtering(cache_with_filters):
750756
assert len(results) == 0
751757

752758

753-
def test_cache_bad_filters(vectorizer, redis_url, worker_id):
759+
def test_cache_bad_filters(client, vectorizer, redis_url, worker_id):
760+
skip_if_no_redisearch(client)
754761
with pytest.raises(ValueError):
755762
cache_instance = SemanticCache(
756763
name=f"test_bad_filters_1_{worker_id}",
@@ -819,6 +826,7 @@ def test_complex_filters(cache_with_filters):
819826

820827

821828
def test_cache_index_overwrite(client, redis_url, worker_id, hf_vectorizer):
829+
skip_if_no_redisearch(client)
822830
# Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly
823831
redis_version = client.info()["redis_version"]
824832
if redis_version.startswith("6.2"):
@@ -921,7 +929,8 @@ def test_no_key_collision_on_identical_prompts(redis_url, worker_id, hf_vectoriz
921929
assert len(filtered_results) == 2
922930

923931

924-
def test_create_cache_with_different_vector_types(worker_id, redis_url):
932+
def test_create_cache_with_different_vector_types(client, worker_id, redis_url):
933+
skip_if_no_redisearch(client)
925934
try:
926935
bfloat_cache = SemanticCache(
927936
name=f"bfloat_cache_{worker_id}", dtype="bfloat16", redis_url=redis_url
@@ -951,6 +960,7 @@ def test_create_cache_with_different_vector_types(worker_id, redis_url):
951960

952961

953962
def test_bad_dtype_connecting_to_existing_cache(client, redis_url, worker_id):
963+
skip_if_no_redisearch(client)
954964
# Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly
955965
redis_version = client.info()["redis_version"]
956966
if redis_version.startswith("6.2"):
@@ -1021,7 +1031,10 @@ def test_deprecated_dtype_argument(redis_url, worker_id):
10211031

10221032

10231033
@pytest.mark.asyncio
1024-
async def test_cache_async_context_manager(redis_url, worker_id, hf_vectorizer):
1034+
async def test_cache_async_context_manager(
1035+
async_client, redis_url, worker_id, hf_vectorizer
1036+
):
1037+
await skip_if_no_redisearch_async(async_client)
10251038
async with SemanticCache(
10261039
name=f"test_cache_async_context_manager_{worker_id}",
10271040
redis_url=redis_url,
@@ -1034,8 +1047,9 @@ async def test_cache_async_context_manager(redis_url, worker_id, hf_vectorizer):
10341047

10351048
@pytest.mark.asyncio
10361049
async def test_cache_async_context_manager_with_exception(
1037-
redis_url, worker_id, hf_vectorizer
1050+
async_client, redis_url, worker_id, hf_vectorizer
10381051
):
1052+
await skip_if_no_redisearch_async(async_client)
10391053
try:
10401054
async with SemanticCache(
10411055
name=f"test_cache_async_context_manager_with_exception_{worker_id}",

tests/integration/test_message_history.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from redisvl.extensions.constants import ID_FIELD_NAME
77
from redisvl.extensions.message_history import MessageHistory, SemanticMessageHistory
8+
from tests.conftest import skip_if_no_redisearch
89

910

1011
@pytest.fixture
@@ -21,6 +22,7 @@ def standard_history(app_name, client):
2122

2223
@pytest.fixture
2324
def semantic_history(app_name, client, hf_vectorizer):
25+
skip_if_no_redisearch(client)
2426
history = SemanticMessageHistory(
2527
app_name, redis_client=client, overwrite=True, vectorizer=hf_vectorizer
2628
)
@@ -326,6 +328,7 @@ def test_standard_clear(standard_history):
326328

327329
# test semantic message history
328330
def test_semantic_specify_client(client, hf_vectorizer):
331+
skip_if_no_redisearch(client)
329332
history = SemanticMessageHistory(
330333
name="test_app",
331334
session_tag="abc",
@@ -616,7 +619,8 @@ def test_semantic_drop(semantic_history):
616619
]
617620

618621

619-
def test_different_vector_dtypes(redis_url):
622+
def test_different_vector_dtypes(client, redis_url):
623+
skip_if_no_redisearch(client)
620624
try:
621625
bfloat_sess = SemanticMessageHistory(
622626
name="bfloat_history", dtype="bfloat16", redis_url=redis_url
@@ -647,6 +651,7 @@ def test_different_vector_dtypes(redis_url):
647651

648652

649653
def test_bad_dtype_connecting_to_exiting_history(client, redis_url):
654+
skip_if_no_redisearch(client)
650655
# Skip this test for Redis 6.2.x as FT.INFO doesn't return dims properly
651656
redis_version = client.info()["redis_version"]
652657
if redis_version.startswith("6.2"):
@@ -674,7 +679,8 @@ def create_same_type():
674679
)
675680

676681

677-
def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16):
682+
def test_vectorizer_dtype_mismatch(client, redis_url, hf_vectorizer_float16):
683+
skip_if_no_redisearch(client)
678684
with pytest.raises(ValueError):
679685
SemanticMessageHistory(
680686
name="test_dtype_mismatch",
@@ -685,7 +691,8 @@ def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16):
685691
)
686692

687693

688-
def test_invalid_vectorizer(redis_url):
694+
def test_invalid_vectorizer(client, redis_url):
695+
skip_if_no_redisearch(client)
689696
with pytest.raises(TypeError):
690697
SemanticMessageHistory(
691698
name="test_invalid_vectorizer",
@@ -695,7 +702,8 @@ def test_invalid_vectorizer(redis_url):
695702
)
696703

697704

698-
def test_passes_through_dtype_to_default_vectorizer(redis_url):
705+
def test_passes_through_dtype_to_default_vectorizer(client, redis_url):
706+
skip_if_no_redisearch(client)
699707
# The default is float32, so we should see float64 if we pass it in.
700708
cache = SemanticMessageHistory(
701709
name="test_pass_through_dtype",
@@ -706,7 +714,8 @@ def test_passes_through_dtype_to_default_vectorizer(redis_url):
706714
assert cache._vectorizer.dtype == "float64"
707715

708716

709-
def test_deprecated_dtype_argument(redis_url):
717+
def test_deprecated_dtype_argument(client, redis_url):
718+
skip_if_no_redisearch(client)
710719
with pytest.warns(DeprecationWarning):
711720
SemanticMessageHistory(
712721
name="float64 history", dtype="float64", redis_url=redis_url, overwrite=True

0 commit comments

Comments
 (0)