Skip to content

Commit 5189c9f

Browse files
committed
feat: add skip_decode parameter to return_fields method (#252)
Implements skip_decode parameter for return_fields() method to improve field deserialization UX. This allows users to skip decoding of binary fields like embeddings while still returning them in query results. - Added optional skip_decode parameter to BaseQuery.return_fields() - Parameter accepts string or list of field names to skip decoding - Maintains backward compatibility when skip_decode is not provided - Comprehensive unit test coverage for all query types - Enhanced skip_decode to use parent's return_field with decode_field=False - Added comprehensive integration tests with real Redis - Maintained full backward compatibility with return_field(decode_field=False) - Tests confirm proper binary field handling (embeddings, image data)
1 parent 27dabc0 commit 5189c9f

File tree

3 files changed

+552
-0
lines changed

3 files changed

+552
-0
lines changed

redisvl/query/query.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def __init__(self, query_string: str = "*"):
4747
# has not been built yet.
4848
self._built_query_string = None
4949

50+
# Initialize skip_decode_fields set
51+
self._skip_decode_fields: Set[str] = set()
52+
5053
def __str__(self) -> str:
5154
"""Return the string representation of the query."""
5255
return " ".join([str(x) for x in self.get_args()])
@@ -107,6 +110,60 @@ def _query_string(self, value: Optional[str]):
107110
"""Setter for _query_string to maintain compatibility with parent class."""
108111
self._built_query_string = value
109112

113+
def return_fields(
114+
self, *fields, skip_decode: Optional[Union[str, List[str]]] = None
115+
):
116+
"""
117+
Set the fields to return with search results.
118+
119+
Args:
120+
*fields: Variable number of field names to return.
121+
skip_decode: Optional field name or list of field names that should not be
122+
decoded. Useful for binary data like embeddings.
123+
124+
Returns:
125+
self: Returns the query object for method chaining.
126+
127+
Raises:
128+
TypeError: If skip_decode is not a string, list, or None.
129+
"""
130+
# Only clear fields when skip_decode is provided (indicating user is explicitly setting fields)
131+
# This preserves backward compatibility when return_fields is called multiple times
132+
if skip_decode is not None:
133+
# Clear existing fields to provide replacement behavior
134+
self._return_fields = []
135+
self._return_fields_decode_as = {}
136+
137+
# Process skip_decode parameter to prepare decode settings
138+
if isinstance(skip_decode, str):
139+
skip_decode_set = {skip_decode}
140+
self._skip_decode_fields = {skip_decode}
141+
elif isinstance(skip_decode, list):
142+
skip_decode_set = set(skip_decode)
143+
self._skip_decode_fields = set(skip_decode)
144+
else:
145+
raise TypeError(
146+
"skip_decode must be a string, list of strings, or None"
147+
)
148+
149+
# Add fields using parent's return_field method with proper decode settings
150+
for field in fields:
151+
if field in skip_decode_set:
152+
# Use return_field with decode_field=False for skip_decode fields
153+
super().return_field(field, decode_field=False)
154+
else:
155+
# Use normal return_field for other fields
156+
super().return_field(field)
157+
else:
158+
# Standard additive behavior (backward compatible)
159+
super().return_fields(*fields)
160+
161+
# Initialize skip_decode_fields if not already set
162+
if not hasattr(self, "_skip_decode_fields"):
163+
self._skip_decode_fields = set()
164+
165+
return self
166+
110167

111168
class FilterQuery(BaseQuery):
112169
def __init__(
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
"""Integration tests for skip_decode parameter in query return_fields (issue #252)."""
2+
3+
import numpy as np
4+
import pytest
5+
from redis import Redis
6+
7+
from redisvl.index import SearchIndex
8+
from redisvl.query import FilterQuery, RangeQuery, VectorQuery
9+
from redisvl.schema import IndexSchema
10+
11+
12+
@pytest.fixture
13+
def sample_schema():
14+
"""Create a sample schema with various field types."""
15+
return IndexSchema.from_dict(
16+
{
17+
"index": {
18+
"name": "test_skip_decode",
19+
"prefix": "doc",
20+
"storage_type": "hash",
21+
},
22+
"fields": [
23+
{"name": "title", "type": "text"},
24+
{"name": "year", "type": "numeric"},
25+
{"name": "description", "type": "text"},
26+
{
27+
"name": "embedding",
28+
"type": "vector",
29+
"attrs": {
30+
"dims": 128,
31+
"algorithm": "flat",
32+
"distance_metric": "cosine",
33+
},
34+
},
35+
{
36+
"name": "image_data",
37+
"type": "tag",
38+
}, # Will store binary data as tag
39+
],
40+
}
41+
)
42+
43+
44+
@pytest.fixture
45+
def search_index(redis_url, sample_schema):
46+
"""Create and populate a test index."""
47+
index = SearchIndex(sample_schema, redis_url=redis_url)
48+
49+
# Clear any existing data
50+
try:
51+
index.delete(drop=True)
52+
except:
53+
pass
54+
55+
# Create the index
56+
index.create(overwrite=True)
57+
58+
# Populate with test data
59+
data = []
60+
for i in range(5):
61+
embedding_vector = np.random.rand(128).astype(np.float32)
62+
doc = {
63+
"title": f"Document {i}",
64+
"year": 2020 + i,
65+
"description": f"This is document number {i}",
66+
"embedding": embedding_vector.tobytes(), # Store as binary
67+
"image_data": f"binary_image_{i}".encode("utf-8"), # Store as binary
68+
}
69+
data.append(doc)
70+
71+
# Load data into Redis
72+
index.load(data, id_field="title")
73+
74+
yield index
75+
76+
# Cleanup
77+
try:
78+
index.delete(drop=True)
79+
except:
80+
pass
81+
82+
83+
class TestSkipDecodeIntegration:
84+
"""Integration tests for skip_decode functionality with real Redis."""
85+
86+
def test_filter_query_skip_decode_single_field(self, search_index):
87+
"""Test FilterQuery with skip_decode for embedding field."""
88+
query = FilterQuery(num_results=10)
89+
query.return_fields("title", "year", "embedding", skip_decode=["embedding"])
90+
91+
results = search_index.query(query)
92+
93+
# Verify we got results
94+
assert len(results) > 0
95+
96+
# Check first result
97+
first_result = results[0]
98+
assert "title" in first_result
99+
assert "year" in first_result
100+
assert "embedding" in first_result
101+
102+
# Title and year should be decoded strings
103+
assert isinstance(first_result["title"], str)
104+
assert isinstance(first_result["year"], str) # Redis returns as string
105+
106+
# Embedding should remain as bytes (not decoded)
107+
assert isinstance(first_result["embedding"], bytes)
108+
109+
def test_filter_query_skip_decode_multiple_fields(self, search_index):
110+
"""Test FilterQuery with skip_decode for multiple binary fields."""
111+
query = FilterQuery(num_results=10)
112+
query.return_fields(
113+
"title",
114+
"year",
115+
"embedding",
116+
"image_data",
117+
skip_decode=["embedding", "image_data"],
118+
)
119+
120+
results = search_index.query(query)
121+
122+
assert len(results) > 0
123+
124+
first_result = results[0]
125+
# Decoded fields
126+
assert isinstance(first_result["title"], str)
127+
assert isinstance(first_result["year"], str)
128+
129+
# Non-decoded fields (should be bytes)
130+
assert isinstance(first_result["embedding"], bytes)
131+
assert isinstance(first_result["image_data"], bytes)
132+
133+
def test_filter_query_no_skip_decode_default(self, search_index):
134+
"""Test FilterQuery without skip_decode (default behavior)."""
135+
query = FilterQuery(num_results=10)
136+
query.return_fields("title", "year", "description")
137+
138+
results = search_index.query(query)
139+
140+
assert len(results) > 0
141+
142+
first_result = results[0]
143+
# All fields should be decoded to strings
144+
assert isinstance(first_result["title"], str)
145+
assert isinstance(first_result["year"], str)
146+
assert isinstance(first_result["description"], str)
147+
148+
def test_vector_query_skip_decode(self, search_index):
149+
"""Test VectorQuery with skip_decode for embedding field."""
150+
# Create a random query vector
151+
query_vector = np.random.rand(128).astype(np.float32)
152+
153+
query = VectorQuery(
154+
vector=query_vector.tolist(),
155+
vector_field_name="embedding",
156+
return_fields=None, # Will set with method
157+
num_results=3,
158+
return_score=True, # Explicitly request distance score
159+
)
160+
161+
# Use skip_decode for embedding
162+
query.return_fields("title", "embedding", skip_decode=["embedding"])
163+
164+
results = search_index.query(query)
165+
166+
assert len(results) > 0
167+
168+
for result in results:
169+
assert isinstance(result["title"], str)
170+
# Embedding should be bytes (not decoded)
171+
assert isinstance(result["embedding"], bytes)
172+
# Distance score is added automatically by VectorQuery when return_score=True
173+
# but may not be in the result dict, just check the fields we requested
174+
175+
def test_range_query_skip_decode(self, search_index):
176+
"""Test RangeQuery with skip_decode for binary fields."""
177+
# Create a random query vector
178+
query_vector = np.random.rand(128).astype(np.float32)
179+
180+
query = RangeQuery(
181+
vector=query_vector.tolist(),
182+
vector_field_name="embedding",
183+
distance_threshold=1.0,
184+
return_fields=None,
185+
num_results=10,
186+
)
187+
188+
query.return_fields("title", "year", "embedding", skip_decode=["embedding"])
189+
190+
results = search_index.query(query)
191+
192+
if len(results) > 0: # Range query might not return results
193+
first_result = results[0]
194+
assert isinstance(first_result["title"], str)
195+
assert isinstance(first_result["year"], str)
196+
assert isinstance(first_result["embedding"], bytes)
197+
198+
def test_backward_compat_return_field_decode_false(self, search_index):
199+
"""Test backward compatibility with return_field(decode_field=False)."""
200+
query = FilterQuery(num_results=10)
201+
202+
# Use old API - return_field with decode_field=False
203+
query.return_field("embedding", decode_field=False)
204+
query.return_field("image_data", decode_field=False)
205+
query.return_fields("title", "year") # These should be decoded
206+
207+
results = search_index.query(query)
208+
209+
assert len(results) > 0
210+
211+
first_result = results[0]
212+
# Decoded fields
213+
assert isinstance(first_result["title"], str)
214+
assert isinstance(first_result["year"], str)
215+
216+
# Non-decoded fields (using old API)
217+
assert isinstance(first_result["embedding"], bytes)
218+
assert isinstance(first_result["image_data"], bytes)
219+
220+
def test_mixed_api_usage(self, search_index):
221+
"""Test mixing old and new API calls."""
222+
query = FilterQuery(num_results=10)
223+
224+
# First use old API
225+
query.return_field("image_data", decode_field=False)
226+
227+
# Then use new API with skip_decode
228+
query.return_fields("title", "year", "embedding", skip_decode=["embedding"])
229+
230+
results = search_index.query(query)
231+
232+
assert len(results) > 0
233+
234+
first_result = results[0]
235+
# The new API call should have replaced everything
236+
# (when skip_decode is provided, it clears previous fields)
237+
assert "title" in first_result
238+
assert "year" in first_result
239+
assert "embedding" in first_result
240+
241+
# image_data should not be in results since return_fields
242+
# with skip_decode clears previous fields
243+
assert "image_data" not in first_result
244+
245+
def test_skip_decode_with_empty_list(self, search_index):
246+
"""Test skip_decode with empty list (all fields decoded)."""
247+
query = FilterQuery(num_results=10)
248+
query.return_fields("title", "year", "description", skip_decode=[])
249+
250+
results = search_index.query(query)
251+
252+
assert len(results) > 0
253+
254+
first_result = results[0]
255+
# All fields should be decoded
256+
assert isinstance(first_result["title"], str)
257+
assert isinstance(first_result["year"], str)
258+
assert isinstance(first_result["description"], str)
259+
260+
def test_skip_decode_with_string_parameter(self, search_index):
261+
"""Test skip_decode accepts a single string instead of list."""
262+
query = FilterQuery(num_results=10)
263+
264+
# Pass a single string instead of list
265+
query.return_fields("title", "embedding", skip_decode="embedding")
266+
267+
results = search_index.query(query)
268+
269+
assert len(results) > 0
270+
271+
first_result = results[0]
272+
assert isinstance(first_result["title"], str)
273+
# Embedding should be bytes (not decoded)
274+
assert isinstance(first_result["embedding"], bytes)
275+
276+
def test_multiple_calls_without_skip_decode(self, search_index):
277+
"""Test multiple return_fields calls without skip_decode (additive behavior)."""
278+
query = FilterQuery(num_results=10)
279+
280+
# Multiple calls without skip_decode should be additive
281+
query.return_fields("title")
282+
query.return_fields("year")
283+
query.return_field("embedding", decode_field=False)
284+
285+
results = search_index.query(query)
286+
287+
assert len(results) > 0
288+
289+
first_result = results[0]
290+
# All fields should be present (additive behavior)
291+
assert "title" in first_result
292+
assert "year" in first_result
293+
assert "embedding" in first_result
294+
295+
# Check types
296+
assert isinstance(first_result["title"], str)
297+
assert isinstance(first_result["year"], str)
298+
assert isinstance(first_result["embedding"], bytes)
299+
300+
def test_replacement_behavior_with_skip_decode(self, search_index):
301+
"""Test that skip_decode parameter triggers replacement behavior."""
302+
query = FilterQuery(num_results=10)
303+
304+
# First set some fields
305+
query.return_fields("title", "description")
306+
307+
# Then call with skip_decode - should replace, not add
308+
query.return_fields("year", "embedding", skip_decode=["embedding"])
309+
310+
results = search_index.query(query)
311+
312+
assert len(results) > 0
313+
314+
first_result = results[0]
315+
# Only fields from second call should be present
316+
assert "year" in first_result
317+
assert "embedding" in first_result
318+
319+
# Fields from first call should NOT be present (replaced)
320+
assert "title" not in first_result
321+
assert "description" not in first_result
322+
323+
# Check embedding is not decoded
324+
assert isinstance(first_result["embedding"], bytes)

0 commit comments

Comments
 (0)