Skip to content

Commit 3d45afd

Browse files
committed
feat: add text field weights support to TextQuery (#360)
Adds the ability to specify weights for text fields in RedisVL queries, enabling users to prioritize certain fields over others in search results.

- Support a dictionary of field:weight mappings in the TextQuery constructor
- Maintain backward compatibility with single string field names
- Add a set_field_weights() method for dynamic weight updates
- Generate proper Redis query syntax with weight modifiers
- Comprehensive validation for positive numeric weights

Example usage:

```python
# Single field with weight
query = TextQuery(text="search", text_field_name={"title": 5.0})

# Multiple fields with weights
query = TextQuery(
    text="search",
    text_field_name={"title": 3.0, "content": 1.5, "tags": 1.0}
)
```
1 parent a56f9a1 commit 3d45afd

File tree

3 files changed

+398
-6
lines changed

3 files changed

+398
-6
lines changed

redisvl/query/query.py

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ class TextQuery(BaseQuery):
801801
def __init__(
802802
self,
803803
text: str,
804-
text_field_name: str,
804+
text_field_name: Union[str, Dict[str, float]],
805805
text_scorer: str = "BM25STD",
806806
filter_expression: Optional[Union[str, FilterExpression]] = None,
807807
return_fields: Optional[List[str]] = None,
@@ -817,7 +817,8 @@ def __init__(
817817
818818
Args:
819819
text (str): The text string to perform the text search with.
820-
text_field_name (str): The name of the document field to perform text search on.
820+
text_field_name (Union[str, Dict[str, float]]): The name of the document field to perform
821+
text search on, or a dictionary mapping field names to their weights.
821822
text_scorer (str, optional): The text scoring algorithm to use.
822823
Defaults to BM25STD. Options are {TFIDF, BM25STD, BM25, TFIDF.DOCNORM, DISMAX, DOCSCORE}.
823824
See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
@@ -849,7 +850,7 @@ def __init__(
849850
TypeError: If stopwords is not a valid iterable set of strings.
850851
"""
851852
self._text = text
852-
self._text_field_name = text_field_name
853+
self._field_weights = self._parse_field_weights(text_field_name)
853854
self._num_results = num_results
854855

855856
self._set_stopwords(stopwords)
@@ -934,15 +935,97 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
934935
[token for token in tokens if token and token not in self._stopwords]
935936
)
936937

938+
def _parse_field_weights(
939+
self, field_spec: Union[str, Dict[str, float]]
940+
) -> Dict[str, float]:
941+
"""Parse the field specification into a weights dictionary.
942+
943+
Args:
944+
field_spec: Either a single field name or dictionary of field:weight mappings
945+
946+
Returns:
947+
Dictionary mapping field names to their weights
948+
"""
949+
if isinstance(field_spec, str):
950+
return {field_spec: 1.0}
951+
elif isinstance(field_spec, dict):
952+
# Validate all weights are numeric and positive
953+
for field, weight in field_spec.items():
954+
if not isinstance(field, str):
955+
raise TypeError(f"Field name must be a string, got {type(field)}")
956+
if not isinstance(weight, (int, float)):
957+
raise TypeError(
958+
f"Weight for field '{field}' must be numeric, got {type(weight)}"
959+
)
960+
if weight <= 0:
961+
raise ValueError(
962+
f"Weight for field '{field}' must be positive, got {weight}"
963+
)
964+
return field_spec
965+
else:
966+
raise TypeError(
967+
"text_field_name must be a string or dictionary of field:weight mappings"
968+
)
969+
970+
def set_field_weights(self, field_weights: Union[str, Dict[str, float]]):
971+
"""Set or update the field weights for the query.
972+
973+
Args:
974+
field_weights: Either a single field name or dictionary of field:weight mappings
975+
"""
976+
self._field_weights = self._parse_field_weights(field_weights)
977+
# Invalidate the query string
978+
self._built_query_string = None
979+
980+
@property
981+
def field_weights(self) -> Dict[str, float]:
982+
"""Get the field weights for the query.
983+
984+
Returns:
985+
Dictionary mapping field names to their weights
986+
"""
987+
return self._field_weights.copy()
988+
989+
@property
990+
def text_field_name(self) -> Union[str, Dict[str, float]]:
991+
"""Get the text field name(s) - for backward compatibility.
992+
993+
Returns:
994+
Either a single field name string (if only one field with weight 1.0)
995+
or a dictionary of field:weight mappings.
996+
"""
997+
if len(self._field_weights) == 1:
998+
field, weight = next(iter(self._field_weights.items()))
999+
if weight == 1.0:
1000+
return field
1001+
return self._field_weights.copy()
1002+
9371003
def _build_query_string(self) -> str:
    """Build the full text-search query string, with optional filtering.

    Each configured field contributes one ``@field:(terms)`` clause; a
    clause whose weight differs from 1.0 also carries a ``$weight``
    attribute. Multiple clauses are OR-ed together inside parentheses, and
    a filter expression other than ``"*"`` is appended after the text part.

    Returns:
        The assembled query string.
    """
    filter_str = self._filter_expression
    if isinstance(filter_str, FilterExpression):
        filter_str = str(filter_str)

    terms = self._tokenize_and_escape_query(self._text)

    def _clause(name, boost):
        # The default weight needs no explicit weight attribute.
        base = f"@{name}:({terms})"
        if boost == 1.0:
            return base
        return f"{base} => {{ $weight: {boost} }}"

    clauses = [_clause(name, boost) for name, boost in self._field_weights.items()]

    # A single clause is used as-is; several are grouped and OR-ed.
    query = clauses[0] if len(clauses) == 1 else "(" + " | ".join(clauses) + ")"

    if filter_str and filter_str != "*":
        query += f" AND {filter_str}"
    return query
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
"""Integration tests for TextQuery with field weights."""
2+
3+
import uuid
4+
5+
import pytest
6+
7+
from redisvl.index import SearchIndex
8+
from redisvl.query import TextQuery
9+
from redisvl.query.filter import Tag
10+
11+
12+
@pytest.fixture
def weighted_index(redis_url, worker_id):
    """Yield a JSON index with several text fields, loaded with four documents.

    The index name and key prefix combine ``worker_id`` with a random suffix.
    The index is deleted, along with its documents, after the test finishes.
    """
    suffix = f"{worker_id}_{str(uuid.uuid4())[:8]}"

    text_fields = [{"name": name, "type": "text"} for name in ("title", "content", "tags")]
    schema = {
        "index": {
            "name": f"weighted_test_idx_{suffix}",
            "prefix": f"weighted_doc_{suffix}",
            "storage_type": "json",
        },
        "fields": text_fields
        + [
            {"name": "category", "type": "tag"},
            {"name": "score", "type": "numeric"},
        ],
    }

    # One row per document, in the column order given below.
    columns = ("id", "title", "content", "tags", "category", "score")
    rows = [
        (
            "1",
            "Redis database introduction",
            "A comprehensive guide to getting started with Redis",
            "tutorial beginner",
            "database",
            95,
        ),
        (
            "2",
            "Advanced caching strategies",
            "Learn about Redis caching patterns and best practices",
            "redis cache performance",
            "optimization",
            88,
        ),
        (
            "3",
            "Python programming basics",
            "Introduction to Python with examples using Redis client",
            "python redis programming",
            "programming",
            90,
        ),
        (
            "4",
            "Data structures overview",
            "Understanding Redis data structures and their applications",
            "redis structures",
            "database",
            85,
        ),
    ]
    documents = [dict(zip(columns, row)) for row in rows]

    index = SearchIndex.from_dict(schema, redis_url=redis_url)
    index.create(overwrite=True)
    index.load(documents)

    yield index

    index.delete(drop=True)
73+
74+
75+
def test_text_query_with_single_weighted_field(weighted_index):
    """A single weighted field returns hits whose top title mentions the term."""
    boosted_title_query = TextQuery(
        text="redis",
        text_field_name={"title": 5.0},
        return_fields=["title", "content"],
        num_results=4,
    )

    hits = weighted_index.query(boosted_title_query)
    assert len(hits) > 0

    # Only the title field is searched, so the best hit must contain the term there.
    best_hit = hits[0]
    assert "redis" in best_hit["title"].lower()
93+
94+
95+
def test_text_query_with_multiple_weighted_fields(weighted_index):
    """Several weighted fields: every hit contains the term in a searched field."""
    searched_fields = ("title", "content", "tags")
    multi_field_query = TextQuery(
        text="redis",
        text_field_name={"title": 3.0, "content": 2.0, "tags": 1.0},
        return_fields=["title", "content", "tags"],
        num_results=4,
    )

    hits = weighted_index.query(multi_field_query)
    assert len(hits) > 0

    # Each hit must match in at least one of the searched fields.
    for hit in hits:
        text_found = any(
            "redis" in hit.get(name, "").lower() for name in searched_fields
        )
        assert text_found
118+
119+
120+
def test_text_query_weights_with_filter(weighted_index):
    """Weighted fields combined with a tag filter only return matching categories."""
    text = "redis"

    # Query with weights and filter
    filter_expr = Tag("category") == "database"
    query = TextQuery(
        text=text,
        text_field_name={"title": 5.0, "content": 1.0},
        filter_expression=filter_expr,
        return_fields=["title", "content", "category"],
        num_results=4,
    )

    results = weighted_index.query(query)

    # The fixture holds "database" documents that mention the term in title or
    # content, so an empty result would make the loop below pass vacuously.
    assert len(results) > 0

    # Should only get database category results
    for result in results:
        assert result["category"] == "database"
139+
140+
141+
def test_dynamic_weight_update(weighted_index):
    """Field weights can be changed on an existing query, which stays usable."""
    query = TextQuery(
        text="redis",
        # Start with equal weights on both fields.
        text_field_name={"title": 1.0, "content": 1.0},
        return_fields=["title", "content"],
        num_results=4,
    )
    hits_equal_weights = weighted_index.query(query)

    # Re-weight the same query object to favour the title field.
    query.set_field_weights({"title": 10.0, "content": 1.0})
    hits_title_boosted = weighted_index.query(query)

    # Ordering may differ between the two runs; only require that both
    # queries return something.
    assert len(hits_equal_weights) > 0
    assert len(hits_title_boosted) > 0
164+
165+
166+
def test_backward_compatibility_single_field(weighted_index):
    """Test that the original single field API still works."""
    text = "redis"

    # Original API with single field name
    query = TextQuery(
        text=text,
        text_field_name="content",
        return_fields=["title", "content"],
        num_results=4,
    )

    results = weighted_index.query(query)
    assert len(results) > 0

    # At least one result should have the term in the content field.
    # A real assertion is used instead of ``assert False`` in a for/else,
    # because ``assert False`` is removed when Python runs with -O.
    assert any(
        "redis" in result.get("content", "").lower() for result in results
    ), "No results with 'redis' in content field"

0 commit comments

Comments
 (0)