-
Notifications
You must be signed in to change notification settings - Fork 60
feat: add skip_decode parameter to return_fields method (#252) #389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+551
−0
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,6 +47,9 @@ def __init__(self, query_string: str = "*"): | |
| # has not been built yet. | ||
| self._built_query_string = None | ||
|
|
||
| # Initialize skip_decode_fields set | ||
| self._skip_decode_fields: Set[str] = set() | ||
|
Comment on lines
+50
to
+51
|
||
|
|
||
def __str__(self) -> str:
    """Render the query as its arguments joined by single spaces."""
    # Every argument from get_args() is stringified before joining.
    return " ".join(map(str, self.get_args()))
|
|
@@ -107,6 +110,58 @@ def _query_string(self, value: Optional[str]): | |
| """Setter for _query_string to maintain compatibility with parent class.""" | ||
| self._built_query_string = value | ||
|
|
||
def return_fields(
    self, *fields, skip_decode: Optional[Union[str, List[str]]] = None
):
    """
    Set the fields to return with search results.

    Two modes, selected by whether ``skip_decode`` is supplied:

    * ``skip_decode is None``: the call is forwarded unchanged to the
      parent's ``return_fields`` (kept for backward compatibility when the
      method is called several times).
    * ``skip_decode`` given: the previously configured return fields and
      their decode settings are reset, and ``fields`` replaces them. Fields
      named in ``skip_decode`` are registered with ``decode_field=False``.

    Args:
        *fields: Variable number of field names to return.
        skip_decode: Optional field name or list of field names that should not be
            decoded. Useful for binary data like embeddings.

    Returns:
        self: Returns the query object for method chaining.

    Raises:
        TypeError: If skip_decode is not a string, list, or None. The query's
            existing return fields are left untouched in that case.
    """
    if skip_decode is None:
        # Standard behavior of the parent class (backward compatible).
        super().return_fields(*fields)

        # Make sure the attribute exists even if __init__ did not set it.
        if not hasattr(self, "_skip_decode_fields"):
            self._skip_decode_fields = set()
        return self

    # Validate BEFORE touching any state, so that a rejected argument does
    # not wipe the fields that were already configured on this query.
    if isinstance(skip_decode, str):
        skip_decode_set = {skip_decode}
    elif isinstance(skip_decode, list):
        skip_decode_set = set(skip_decode)
    else:
        raise TypeError("skip_decode must be a string or list of strings")

    # Replacement behavior: drop whatever was configured before.
    self._return_fields = []
    self._return_fields_decode_as = {}
    self._skip_decode_fields = skip_decode_set

    # Register each field through the parent's return_field so the decode
    # setting is applied per field.
    for field in fields:
        if field in skip_decode_set:
            super().return_field(field, decode_field=False)
        else:
            super().return_field(field)

    return self
|
|
||
|
|
||
| class FilterQuery(BaseQuery): | ||
| def __init__( | ||
|
|
||
325 changes: 325 additions & 0 deletions
325
tests/integration/test_skip_decode_fields_integration.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,325 @@ | ||
| """Integration tests for skip_decode parameter in query return_fields (issue #252).""" | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| from redis import Redis | ||
|
|
||
| from redisvl.exceptions import RedisSearchError | ||
| from redisvl.index import SearchIndex | ||
| from redisvl.query import FilterQuery, RangeQuery, VectorQuery | ||
| from redisvl.schema import IndexSchema | ||
|
|
||
|
|
||
@pytest.fixture
def sample_schema():
    """Build the hash-backed index schema shared by the skip_decode tests."""
    index_settings = {
        "name": "test_skip_decode",
        "prefix": "doc",
        "storage_type": "hash",
    }

    # 128-dimensional flat vector field compared with cosine distance.
    embedding_field = {
        "name": "embedding",
        "type": "vector",
        "attrs": {
            "dims": 128,
            "algorithm": "flat",
            "distance_metric": "cosine",
        },
    }

    field_specs = [
        {"name": "title", "type": "text"},
        {"name": "year", "type": "numeric"},
        {"name": "description", "type": "text"},
        embedding_field,
        # Declared as a tag; the tests store binary data in it.
        {"name": "image_data", "type": "tag"},
    ]

    return IndexSchema.from_dict({"index": index_settings, "fields": field_specs})
|
|
||
|
|
||
@pytest.fixture
def search_index(redis_url, sample_schema):
    """Yield a freshly created index holding five documents, then drop it."""
    index = SearchIndex(sample_schema, redis_url=redis_url)

    def _drop_quietly():
        # Best effort: a RedisSearchError here (e.g. the index is absent)
        # is deliberately ignored.
        try:
            index.delete(drop=True)
        except RedisSearchError:
            pass

    # Start from a clean slate, then (re)create the index.
    _drop_quietly()
    index.create(overwrite=True)

    # Five documents; the embedding and image_data values are raw bytes.
    documents = [
        {
            "title": f"Document {i}",
            "year": 2020 + i,
            "description": f"This is document number {i}",
            "embedding": np.random.rand(128).astype(np.float32).tobytes(),
            "image_data": f"binary_image_{i}".encode("utf-8"),
        }
        for i in range(5)
    ]
    index.load(documents, id_field="title")

    yield index

    # Teardown: remove the index again, ignoring cleanup errors.
    _drop_quietly()
|
|
||
|
|
||
class TestSkipDecodeIntegration:
    """Exercise ``return_fields(..., skip_decode=...)`` against a real Redis index."""

    @staticmethod
    def _first_hit(results):
        """Assert that the search returned something and give back its first document."""
        assert len(results) > 0
        return results[0]

    def test_filter_query_skip_decode_single_field(self, search_index):
        """FilterQuery: only the embedding field is left undecoded."""
        q = FilterQuery(num_results=10)
        q.return_fields("title", "year", "embedding", skip_decode=["embedding"])

        hit = self._first_hit(search_index.query(q))

        for name in ("title", "year", "embedding"):
            assert name in hit

        # Decoded values come back as str (numeric "year" included).
        for name in ("title", "year"):
            assert isinstance(hit[name], str)

        # The skipped field stays as raw bytes.
        assert isinstance(hit["embedding"], bytes)

    def test_filter_query_skip_decode_multiple_fields(self, search_index):
        """FilterQuery: several binary fields can be skipped at once."""
        binary_fields = ["embedding", "image_data"]
        q = FilterQuery(num_results=10)
        q.return_fields(
            "title", "year", "embedding", "image_data", skip_decode=binary_fields
        )

        hit = self._first_hit(search_index.query(q))

        for name in ("title", "year"):
            assert isinstance(hit[name], str)
        for name in ("embedding", "image_data"):
            assert isinstance(hit[name], bytes)

    def test_filter_query_no_skip_decode_default(self, search_index):
        """FilterQuery: without skip_decode every requested field is decoded."""
        q = FilterQuery(num_results=10)
        q.return_fields("title", "year", "description")

        hit = self._first_hit(search_index.query(q))

        for name in ("title", "year", "description"):
            assert isinstance(hit[name], str)

    def test_vector_query_skip_decode(self, search_index):
        """VectorQuery: the embedding field is returned as bytes for every hit."""
        probe = np.random.rand(128).astype(np.float32)
        q = VectorQuery(
            vector=probe.tolist(),
            vector_field_name="embedding",
            return_fields=None,  # set through the method call below
            num_results=3,
            return_score=True,
        )
        q.return_fields("title", "embedding", skip_decode=["embedding"])

        hits = search_index.query(q)
        assert len(hits) > 0

        # Only the explicitly requested fields are checked; the distance
        # score is not asserted on.
        for hit in hits:
            assert isinstance(hit["title"], str)
            assert isinstance(hit["embedding"], bytes)

    def test_range_query_skip_decode(self, search_index):
        """RangeQuery: skip_decode keeps the embedding as bytes."""
        probe = np.random.rand(128).astype(np.float32)
        q = RangeQuery(
            vector=probe.tolist(),
            vector_field_name="embedding",
            distance_threshold=1.0,
            return_fields=None,
            num_results=10,
        )
        q.return_fields("title", "year", "embedding", skip_decode=["embedding"])

        hits = search_index.query(q)

        # A range query may legitimately match nothing; then there is
        # nothing to check.
        if len(hits) == 0:
            return

        hit = hits[0]
        assert isinstance(hit["title"], str)
        assert isinstance(hit["year"], str)
        assert isinstance(hit["embedding"], bytes)

    def test_backward_compat_return_field_decode_false(self, search_index):
        """The older return_field(decode_field=False) API keeps working."""
        q = FilterQuery(num_results=10)

        # Old API: one call per undecoded field, then the decoded ones.
        for name in ("embedding", "image_data"):
            q.return_field(name, decode_field=False)
        q.return_fields("title", "year")

        hit = self._first_hit(search_index.query(q))

        for name in ("title", "year"):
            assert isinstance(hit[name], str)
        for name in ("embedding", "image_data"):
            assert isinstance(hit[name], bytes)

    def test_mixed_api_usage(self, search_index):
        """A skip_decode call after an old-API call replaces the earlier fields."""
        q = FilterQuery(num_results=10)
        q.return_field("image_data", decode_field=False)  # old API first
        q.return_fields("title", "year", "embedding", skip_decode=["embedding"])

        hit = self._first_hit(search_index.query(q))

        # Fields from the skip_decode call are present ...
        for name in ("title", "year", "embedding"):
            assert name in hit

        # ... while the field registered before it was cleared.
        assert "image_data" not in hit

    def test_skip_decode_with_empty_list(self, search_index):
        """An empty skip_decode list leaves every field decoded."""
        q = FilterQuery(num_results=10)
        q.return_fields("title", "year", "description", skip_decode=[])

        hit = self._first_hit(search_index.query(q))

        for name in ("title", "year", "description"):
            assert isinstance(hit[name], str)

    def test_skip_decode_with_string_parameter(self, search_index):
        """skip_decode also accepts a bare string instead of a list."""
        q = FilterQuery(num_results=10)
        q.return_fields("title", "embedding", skip_decode="embedding")

        hit = self._first_hit(search_index.query(q))

        assert isinstance(hit["title"], str)
        assert isinstance(hit["embedding"], bytes)

    def test_multiple_calls_without_skip_decode(self, search_index):
        """Repeated calls without skip_decode accumulate fields."""
        q = FilterQuery(num_results=10)
        q.return_fields("title")
        q.return_fields("year")
        q.return_field("embedding", decode_field=False)

        hit = self._first_hit(search_index.query(q))

        # Everything requested across the three calls is returned.
        for name in ("title", "year", "embedding"):
            assert name in hit

        assert isinstance(hit["title"], str)
        assert isinstance(hit["year"], str)
        assert isinstance(hit["embedding"], bytes)

    def test_replacement_behavior_with_skip_decode(self, search_index):
        """Passing skip_decode replaces, rather than extends, earlier fields."""
        q = FilterQuery(num_results=10)
        q.return_fields("title", "description")
        q.return_fields("year", "embedding", skip_decode=["embedding"])

        hit = self._first_hit(search_index.query(q))

        # Only the second call's fields survive.
        for name in ("year", "embedding"):
            assert name in hit
        for name in ("title", "description"):
            assert name not in hit

        assert isinstance(hit["embedding"], bytes)
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `Set` type is used without importing it. This will cause a `NameError` at runtime. Either import `Set` from `typing` or use the built-in `set` type annotation.