Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decode search results at field level #3309

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Decode search results at field level
Fixes: #2772, #2275
  • Loading branch information
uglide committed Jul 9, 2024
commit 7ca2f2914080bca067e8bc87d52e9f506a1c7df9
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ urllib3<2
uvloop
vulture>=2.3.0
wheel>=0.30.0
numpy>=1.24.0
5 changes: 3 additions & 2 deletions docs/examples/search_vector_similarity_examples.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions redis/commands/search/_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def to_string(s, encoding: str = "utf-8"):
    """Coerce *s* to ``str`` when it is a byte string.

    ``bytes`` values are decoded with *encoding*, silently dropping any
    undecodable byte sequences (``"ignore"`` error handler).  ``str``
    values — and anything that is not a string at all — are returned
    unchanged.
    """
    if isinstance(s, bytes):
        return s.decode(encoding, "ignore")
    # Already text, or not a string we care about: pass through as-is.
    return s
1 change: 1 addition & 0 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _parse_search(self, res, **kwargs):
duration=kwargs["duration"],
has_payload=kwargs["query"]._with_payloads,
with_scores=kwargs["query"]._with_scores,
field_encodings=kwargs["query"]._return_fields_decode_as,
)

def _parse_aggregate(self, res, **kwargs):
Expand Down
23 changes: 19 additions & 4 deletions redis/commands/search/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, query_string: str) -> None:
self._in_order: bool = False
self._sortby: Optional[SortbyField] = None
self._return_fields: List = []
self._return_fields_decode_as: dict = {}
self._summarize_fields: List = []
self._highlight_fields: List = []
self._language: Optional[str] = None
Expand All @@ -53,13 +54,27 @@ def limit_ids(self, *ids) -> "Query":

def return_fields(self, *fields) -> "Query":
    """Add several fields to the query's return-field list.

    Each name is forwarded to :meth:`return_field` with its default
    decoding options; returns ``self`` for chaining.
    """
    for field_name in fields:
        self.return_field(field_name)
    return self

def return_field(
    self,
    field: str,
    as_field: Optional[str] = None,
    decode_field: Optional[bool] = True,
    encoding: Optional[str] = "utf8",
) -> "Query":
    """
    Add a field to the list of fields to return.

    - **field**: The field to include in query results
    - **as_field**: The alias for the field (``RETURN ... AS alias``)
    - **decode_field**: Whether to decode the field from bytes to string
    - **encoding**: The encoding to use when decoding the field

    Returns ``self`` for chaining.
    """
    self._return_fields.append(field)
    # Record how (or whether) to decode this field when parsing results;
    # ``None`` means "leave the raw bytes untouched".
    self._return_fields_decode_as[field] = encoding if decode_field else None
    if as_field is not None:
        self._return_fields += ("AS", as_field)
        # Search results come back keyed by the alias, not the original
        # field name, so the decode rule must be registered under the
        # alias as well — otherwise aliased fields never match in Result.
        self._return_fields_decode_as[as_field] = (
            encoding if decode_field else None
        )
    return self
Expand Down
44 changes: 29 additions & 15 deletions redis/commands/search/result.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from ._util import to_string
from .document import Document

Expand All @@ -9,11 +11,19 @@ class Result:
"""

def __init__(
self, res, hascontent, duration=0, has_payload=False, with_scores=False
self,
res,
hascontent,
duration=0,
has_payload=False,
with_scores=False,
field_encodings: Optional[dict] = None,
):
"""
- **snippets**: An optional dictionary of the form
{field: snippet_size} for snippet formatting
- duration: the execution time of the query
- has_payload: whether the query has payloads
- with_scores: whether the query has scores
- field_encodings: a dictionary of field encodings if any is provided
"""

self.total = res[0]
Expand All @@ -39,18 +49,22 @@ def __init__(

fields = {}
if hascontent and res[i + fields_offset] is not None:
fields = (
dict(
dict(
zip(
map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2]),
)
)
)
if hascontent
else {}
)
keys = map(to_string, res[i + fields_offset][::2])
values = res[i + fields_offset][1::2]

for key, value in zip(keys, values):
if field_encodings is None or key not in field_encodings:
fields[key] = to_string(value)
gerzse marked this conversation as resolved.
Show resolved Hide resolved
continue

encoding = field_encodings[key]

# If the encoding is None, we don't need to decode the value
if encoding is None:
fields[key] = value
else:
fields[key] = to_string(value, encoding=encoding)

try:
del fields["id"]
except KeyError:
Expand Down
63 changes: 63 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from io import TextIOWrapper

import numpy as np
import pytest
import redis
import redis.commands.search
Expand Down Expand Up @@ -113,6 +114,13 @@ def client(request, stack_url):
return r


@pytest.fixture
def binary_client(request, stack_url):
    """A Redis client whose responses are raw bytes (no decoding)."""
    client = _get_client(
        redis.Redis, request, decode_responses=False, from_url=stack_url
    )
    client.flushdb()
    return client


@pytest.mark.redismod
def test_client(client):
num_docs = 500
Expand Down Expand Up @@ -1705,6 +1713,61 @@ def test_search_return_fields(client):
assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"]


@pytest.mark.redismod
def test_binary_and_text_fields(binary_client):
    """A mixed hash round-trips through FT.SEARCH: the vector field is
    returned as raw bytes (``decode_field=False``) while the tag field is
    decoded to text (``decode_field=True``)."""
    assert (
        binary_client.get_connection_kwargs()["decode_responses"] is False
    ), "This feature is only available when decode_responses is False"

    fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)

    index_name = "mixed_index"
    mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
    binary_client.hset(f"{index_name}:1", mapping=mixed_data)

    # NOTE(review): the schema indexes "embeddings_bio" while the stored
    # hash field is "vector_emb" — RETURN appears to load the value straight
    # from the hash regardless; confirm the name mismatch is intentional.
    schema = (
        TagField("first_name"),
        VectorField(
            "embeddings_bio",
            algorithm="HNSW",
            attributes={
                "TYPE": "FLOAT32",
                "DIM": 4,
                "DISTANCE_METRIC": "COSINE",
            },
        ),
    )

    binary_client.ft(index_name).create_index(
        fields=schema,
        definition=IndexDefinition(
            prefix=[f"{index_name}:"], index_type=IndexType.HASH
        ),
    )

    # Sanity check: the raw vector survives a plain HGET round-trip.
    bytes_person_1 = binary_client.hget(f"{index_name}:1", "vector_emb")
    decoded_vec_from_hash = np.frombuffer(bytes_person_1, dtype=np.float32)
    assert np.array_equal(decoded_vec_from_hash, fake_vec), "The vectors are not equal"

    query = (
        Query("*")
        .return_field("vector_emb", decode_field=False)
        .return_field("first_name", decode_field=True)
    )
    docs = binary_client.ft(index_name).search(query=query, query_params={}).docs
    decoded_vec_from_search_results = np.frombuffer(
        docs[0]["vector_emb"], dtype=np.float32
    )

    assert np.array_equal(
        decoded_vec_from_search_results, fake_vec
    ), "The vectors are not equal"

    assert (
        docs[0]["first_name"] == mixed_data["first_name"]
    ), "The first name is not decoded correctly"

gerzse marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.redismod
def test_synupdate(client):
definition = IndexDefinition(index_type=IndexType.HASH)
Expand Down
Loading