Skip to content

Commit b32ed95

Browse files
committed
fix: handle ClusterPipeline AttributeError in get_protocol_version (#365)
Wrap redis-py's get_protocol_version to catch AttributeError when ClusterPipeline objects lack nodes_manager attribute. Returns None on error, causing NEVER_DECODE to be set (safe fallback behavior). This fixes crashes when using SearchIndex.load() with Redis Cluster where batch operations create ClusterPipeline objects internally. Fixes #365
1 parent 547085a commit b32ed95

File tree

6 files changed

+264
-2
lines changed

6 files changed

+264
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,4 @@ tests/data
230230
.cursor
231231
.junie
232232
.undodir
233+
.claude/settings.local.json

redisvl/index/index.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@
4949

5050
from redis import __version__ as redis_version
5151
from redis.client import NEVER_DECODE
52-
from redis.commands.helpers import get_protocol_version # type: ignore
52+
53+
from redisvl.utils.redis_protocol import get_protocol_version
5354

5455
# Redis 5.x compatibility (6 fixed the import path)
5556
if redis_version.startswith("5"):

redisvl/redis/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from redis import __version__ as redis_version
88
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
99
from redis.client import NEVER_DECODE, Pipeline
10-
from redis.commands.helpers import get_protocol_version
1110
from redis.commands.search import AsyncSearch, Search
1211
from redis.commands.search.commands import (
1312
CREATE_CMD,
@@ -23,6 +22,8 @@
2322
)
2423
from redis.commands.search.field import Field
2524

25+
from redisvl.utils.redis_protocol import get_protocol_version
26+
2627
# Redis 5.x compatibility (6 fixed the import path)
2728
if redis_version.startswith("5"):
2829
from redis.commands.search.indexDefinition import ( # type: ignore[import-untyped]

redisvl/utils/redis_protocol.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
2+
Wrapper for redis-py's get_protocol_version to handle edge cases.
3+
4+
This fixes issue #365 where ClusterPipeline objects may not have a nodes_manager attribute.
5+
"""
6+
7+
from typing import Optional, Union
8+
9+
from redis.asyncio.cluster import ClusterPipeline as AsyncClusterPipeline
10+
from redis.cluster import ClusterPipeline
11+
from redis.commands.helpers import get_protocol_version as redis_get_protocol_version
12+
13+
14+
def get_protocol_version(client) -> Optional[Union[int, str]]:
    """
    Safe wrapper for redis-py's get_protocol_version that handles edge cases.

    The main issue is that ClusterPipeline objects may not always have a
    nodes_manager attribute properly set, which makes the redis-py helper
    raise AttributeError (issue #365).

    Args:
        client: Redis client, pipeline, or cluster pipeline object

    Returns:
        The protocol version exactly as redis-py reports it, or None if it
        cannot be determined. The value is not normalized: callers in this
        package accept both the string and the integer form (e.g. "3" and 3),
        so compare against both.
    """
    try:
        # redis-py's function returns None for unknown client types.
        return redis_get_protocol_version(client)
    except AttributeError:
        # Raised when a ClusterPipeline has no nodes_manager. Return None and
        # let the caller decide; callers treat anything that is not "3"/3 as
        # "not RESP3".
        return None
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""
2+
Tests ClusterPipeline
3+
"""
4+
5+
import pytest
6+
from redis.cluster import RedisCluster
7+
from redis.commands.helpers import get_protocol_version
8+
9+
from redisvl.index import SearchIndex
10+
from redisvl.schema import IndexSchema
11+
12+
13+
@pytest.mark.requires_cluster
def test_real_cluster_pipeline_get_protocol_version(redis_cluster_url):
    """
    Test that redisvl's get_protocol_version wrapper works with a real
    ClusterPipeline (issue #365).
    """
    # Import the wrapper locally so this test exercises the redisvl fix and
    # not the upstream redis-py helper, which is the function that raised
    # AttributeError in issue #365.
    from redisvl.utils.redis_protocol import get_protocol_version

    # Real Redis Cluster client and a real pipeline created from it.
    cluster_client = RedisCluster.from_url(redis_cluster_url)
    try:
        pipeline = cluster_client.pipeline()

        # With the wrapper in place this must not raise AttributeError.
        protocol = get_protocol_version(pipeline)

        # Protocol is reported as "2"/"3", 2/3, or None when undetermined.
        assert protocol in [None, "2", "3", 2, 3], f"Unexpected protocol: {protocol}"
    finally:
        # Always release the cluster connections, even if the assertion fails.
        cluster_client.close()
33+
34+
35+
@pytest.mark.requires_cluster
def test_real_searchindex_with_cluster_batch_operations(redis_cluster_url):
    """
    Exercise SearchIndex.load() against Redis Cluster (see issue #365).
    """
    # Hash-backed index with one tag field and one text field.
    schema = IndexSchema.from_dict(
        {
            "index": {"name": "test-real-365", "prefix": "doc", "storage_type": "hash"},
            "fields": [
                {"name": "id", "type": "tag"},
                {"name": "text", "type": "text"},
            ],
        }
    )

    index = SearchIndex(schema, redis_url=redis_cluster_url)
    index.create(overwrite=True)

    try:
        records = [{"id": f"item{i}", "text": f"Document {i}"} for i in range(10)]

        # A batch_size smaller than the record count forces several pipeline
        # operations inside load().
        loaded_keys = index.load(data=records, id_field="id", batch_size=3)

        assert len(loaded_keys) == 10
        for key in loaded_keys:
            assert key.startswith("doc:")
    finally:
        # Drop the index whether or not the assertions passed.
        index.delete()
75+
76+
77+
@pytest.mark.requires_cluster
def test_cluster_pipeline_protocol_version_directly():
    """
    Test redisvl's get_protocol_version wrapper with various cluster
    configurations (default protocol, explicit RESP2, explicit RESP3).

    The test is skipped only when the cluster cannot be reached. Assertion
    failures and unexpected errors (including the AttributeError from issue
    #365) fail the test instead of being reported as a skip.
    """
    import os

    from redis.exceptions import ConnectionError as RedisConnectionError
    from redis.exceptions import RedisClusterException
    from redis.exceptions import TimeoutError as RedisTimeoutError

    # Exercise the redisvl wrapper, not the upstream redis-py helper.
    from redisvl.utils.redis_protocol import get_protocol_version

    cluster_url = os.getenv("REDIS_CLUSTER_URL", "redis://localhost:7000")

    # (extra from_url kwargs, acceptable protocol values)
    cases = [
        ({}, [None, "2", "3", 2, 3]),
        ({"protocol": 2}, [2, "2", None]),
        ({"protocol": 3}, [3, "3", None]),
    ]

    for connect_kwargs, expected in cases:
        # Only a failure to connect means "no cluster available".
        try:
            cluster = RedisCluster.from_url(cluster_url, **connect_kwargs)
        except (
            RedisClusterException,
            RedisConnectionError,
            RedisTimeoutError,
        ) as e:
            pytest.skip(f"Redis Cluster not available: {e}")

        try:
            # This should work without AttributeError.
            protocol = get_protocol_version(cluster.pipeline())
            assert protocol in expected, f"Unexpected protocol: {protocol}"
        finally:
            # Release connections even when the assertion fails.
            cluster.close()
114+
115+
116+
@pytest.mark.requires_cluster
def test_batch_search_with_real_cluster(redis_cluster_url):
    """
    Run batch_search against Redis Cluster; per issue #365 it relies on
    get_protocol_version for its pipelines.
    """
    from redisvl.query import FilterQuery

    # JSON-backed index with two tag fields.
    schema = IndexSchema.from_dict(
        {
            "index": {"name": "test-batch-365", "prefix": "batch", "storage_type": "json"},
            "fields": [
                {"name": "id", "type": "tag"},
                {"name": "category", "type": "tag"},
            ],
        }
    )
    index = SearchIndex(schema, redis_url=redis_cluster_url)
    index.create(overwrite=True)

    try:
        # Fifteen documents spread over three categories.
        documents = [{"id": f"doc{i}", "category": f"cat{i % 3}"} for i in range(15)]
        index.load(data=documents, id_field="id")

        # One filter query per category.
        filter_queries = [
            FilterQuery(filter_expression=f"@category:{{cat{i}}}") for i in range(3)
        ]
        search_args = [(fq.query, fq.params) for fq in filter_queries]

        # A batch_size below the query count splits the work into batches.
        results = index.batch_search(search_args, batch_size=2)

        assert len(results) == 3
    finally:
        index.delete()
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""
2+
Unit tests for the redis_protocol wrapper.
3+
"""
4+
5+
from unittest.mock import Mock
6+
7+
import pytest
8+
from redis.cluster import ClusterPipeline
9+
10+
from redisvl.utils.redis_protocol import get_protocol_version
11+
12+
13+
def test_get_protocol_version_handles_missing_nodes_manager():
    """
    A ClusterPipeline without a nodes_manager attribute must yield None
    rather than raising AttributeError (issue #365).
    """
    pipeline = Mock(spec=ClusterPipeline)
    # Make sure attribute access for nodes_manager fails on the mock.
    if hasattr(pipeline, "nodes_manager"):
        del pipeline.nodes_manager

    assert get_protocol_version(pipeline) is None
27+
28+
29+
def test_get_protocol_version_with_valid_nodes_manager():
    """
    When nodes_manager is present, the configured protocol is returned.
    """
    pipeline = Mock(spec=ClusterPipeline)
    # nodes_manager carries the connection kwargs holding the protocol.
    pipeline.nodes_manager = Mock(connection_kwargs={"protocol": "3"})

    assert get_protocol_version(pipeline) == "3"
41+
42+
43+
def test_get_protocol_version_with_none_client():
    """
    Passing None as the client is handled gracefully and yields None.
    """
    assert get_protocol_version(None) is None
49+
50+
51+
def test_protocol_version_affects_never_decode():
    """
    A None protocol version must lead to NEVER_DECODE being set, mirroring
    how redisvl consumes the value.
    """
    from redis.client import NEVER_DECODE

    pipeline = Mock(spec=ClusterPipeline)
    if hasattr(pipeline, "nodes_manager"):
        del pipeline.nodes_manager

    protocol = get_protocol_version(pipeline)

    # Simulates the code in index.py and utils.py: anything that is not
    # RESP3 gets NEVER_DECODE.
    options = {} if protocol in ["3", 3] else {NEVER_DECODE: True}

    assert protocol is None
    assert NEVER_DECODE in options

0 commit comments

Comments
 (0)