Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLUSTER] Fix scan command cursors & Fix scan_iter #2054

Merged
merged 3 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGES
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@

* Add `items` parameter to `hset` signature
* Create codeql-analysis.yml (#1988). Thanks @chayim
* Create codeql-analysis.yml (#1988). Thanks @chayim
* Add limited support for Lua scripting with RedisCluster
* Implement `.lock()` method on RedisCluster
* Fix cursor returned by SCAN for RedisCluster & change default target to PRIMARIES
* Fix scan_iter for RedisCluster

* 4.1.3 (Feb 8, 2022)
* Fix flushdb and flushall (#1926)
Expand Down
16 changes: 10 additions & 6 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from collections import OrderedDict

from redis.client import CaseInsensitiveDict, PubSub, Redis
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
from redis.commands import CommandsParser, RedisClusterCommands
from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
Expand Down Expand Up @@ -51,10 +51,14 @@ def get_connection(redis_node, *args, **options):


# Merge the per-node replies of a cluster-wide SCAN into one result.
# `res` is consumed as a mapping (`.values()` / `.items()`); the second loop
# names its keys `node_name`.
# NOTE(review): indentation is missing in this rendering and two bodies appear
# back to back. This looks like a diff view (removed lines first, added lines
# after) -- confirm against redis/cluster.py before relying on either body.
def parse_scan_result(command, res, **options):
# First body: concatenate element [1] of every reply and return a literal 0
# in the cursor position.
keys_list = []
for primary_res in res.values():
keys_list += primary_res[1]
return 0, keys_list
# Second body: run parse_scan on each reply, store the cursor it returns
# under that reply's key, and concatenate the parsed items.
cursors = {}
ret = []
for node_name, response in res.items():
cursor, r = parse_scan(response, **options)
cursors[node_name] = cursor
ret += r

# Result is a pair: (dict of cursor per key of `res`, combined item list).
return cursors, ret


def parse_pubsub_numsub(command, res, **options):
Expand Down Expand Up @@ -244,7 +248,6 @@ class RedisCluster(RedisClusterCommands):
"INFO",
"SHUTDOWN",
"KEYS",
"SCAN",
"DBSIZE",
"BGSAVE",
"SLOWLOG GET",
Expand Down Expand Up @@ -298,6 +301,7 @@ class RedisCluster(RedisClusterCommands):
"FUNCTION LIST",
"FUNCTION LOAD",
"FUNCTION RESTORE",
"SCAN",
"SCRIPT EXISTS",
"SCRIPT FLUSH",
"SCRIPT LOAD",
Expand Down
38 changes: 38 additions & 0 deletions redis/commands/cluster.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Iterator, Union

from redis.crc import key_slot
from redis.exceptions import RedisClusterException, RedisError
from redis.typing import PatternT

from .core import (
ACLCommands,
Expand Down Expand Up @@ -206,6 +209,41 @@ def stralgo(
**kwargs,
)

# Generator that yields the items returned by self.scan(), first from one
# initial call and then from follow-up calls per node until every node's
# cursor is 0.
# `match`, `count`, `_type` and any extra **kwargs are forwarded to self.scan().
# NOTE(review): indentation was lost in this rendering; the loop nesting
# described in the comments below is inferred -- confirm against
# redis/commands/cluster.py.
def scan_iter(
self,
match: Union[PatternT, None] = None,
count: Union[int, None] = None,
_type: Union[str, None] = None,
**kwargs,
) -> Iterator:
# Do the first query with cursor=0 for all nodes
# self.scan() is unpacked as (cursors, data); `cursors` is used below as a
# mapping of node name -> cursor.
cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs)
yield from data

# Keep only the nodes that still have a non-zero cursor to follow.
cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0}
if cursors:
# Get nodes by name
nodes = {name: self.get_node(node_name=name) for name in cursors.keys()}

# Iterate over each node till its cursor is 0
# Drop any caller-supplied target_nodes: the calls below pass target_nodes
# explicitly alongside **kwargs, and a duplicate keyword would raise TypeError.
kwargs.pop("target_nodes", None)
while cursors:
for name, cursor in cursors.items():
cur, data = self.scan(
cursor=cursor,
match=match,
count=count,
_type=_type,
target_nodes=nodes[name],
**kwargs,
)
yield from data
# The reply's cursor part is indexed by node name even for this
# single-node call. Only an existing key is reassigned here, so the dict
# being iterated does not change size.
cursors[name] = cur[name]

# Rebuild the mapping without the nodes whose cursor has reached 0.
cursors = {
name: cursor for name, cursor in cursors.items() if cursor != 0
}


class RedisClusterCommands(
ClusterMultiKeyCommands,
Expand Down
59 changes: 45 additions & 14 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,29 +1773,60 @@ def test_cluster_scan(self, r):
r.set("a", 1)
r.set("b", 2)
r.set("c", 3)
cursor, keys = r.scan(target_nodes="primaries")
assert cursor == 0
assert set(keys) == {b"a", b"b", b"c"}
_, keys = r.scan(match="a", target_nodes="primaries")
assert set(keys) == {b"a"}

for target_nodes, nodes in zip(
["primaries", "replicas"], [r.get_primaries(), r.get_replicas()]
):
cursors, keys = r.scan(target_nodes=target_nodes)
assert sorted(keys) == [b"a", b"b", b"c"]
assert sorted(cursors.keys()) == sorted(node.name for node in nodes)
assert all(cursor == 0 for cursor in cursors.values())

cursors, keys = r.scan(match="a*", target_nodes=target_nodes)
assert sorted(keys) == [b"a"]
assert sorted(cursors.keys()) == sorted(node.name for node in nodes)
assert all(cursor == 0 for cursor in cursors.values())

# Exercises r.scan() with the `_type` filter on the cluster client fixture `r`.
# NOTE(review): presumably the decorator skips the test on servers older than
# 6.0.0 -- the helper is defined outside this view.
# NOTE(review): indentation is missing and an older assertion pair sits next to
# the newer loop. This looks like a rendered diff -- confirm against
# tests/test_cluster.py.
@skip_if_server_version_lt("6.0.0")
def test_cluster_scan_type(self, r):
# Seed three sets plus one hash and one list, so the SET filter has
# non-set keys to exclude.
r.sadd("a-set", 1)
r.sadd("b-set", 1)
r.sadd("c-set", 1)
r.hset("a-hash", "foo", 2)
r.lpush("a-list", "aux", 3)
# Check against primaries only; the cursor part of the reply is ignored.
_, keys = r.scan(match="a*", _type="SET", target_nodes="primaries")
assert set(keys) == {b"a-set"}

# Repeat the checks for each target name, paired with the node list the
# client reports for it.
for target_nodes, nodes in zip(
["primaries", "replicas"], [r.get_primaries(), r.get_replicas()]
):
# Type filter alone: all three sets, one cursor entry per targeted node,
# every cursor equal to 0.
cursors, keys = r.scan(_type="SET", target_nodes=target_nodes)
assert sorted(keys) == [b"a-set", b"b-set", b"c-set"]
assert sorted(cursors.keys()) == sorted(node.name for node in nodes)
assert all(cursor == 0 for cursor in cursors.values())

# Type filter combined with a pattern: only the set whose name starts with "a".
cursors, keys = r.scan(_type="SET", match="a*", target_nodes=target_nodes)
assert sorted(keys) == [b"a-set"]
assert sorted(cursors.keys()) == sorted(node.name for node in nodes)
assert all(cursor == 0 for cursor in cursors.values())

# Exercises r.scan_iter() on the cluster client fixture `r`, with and without
# a match pattern, against primaries and replicas.
# NOTE(review): presumably the decorator skips the test on servers older than
# 2.8.0 -- the helper is defined outside this view.
# NOTE(review): indentation is missing, and the a/b/c setup with its two
# set(...) assertions looks like the pre-change body left in by a diff
# rendering -- confirm against tests/test_cluster.py.
@skip_if_server_version_lt("2.8.0")
def test_cluster_scan_iter(self, r):
r.set("a", 1)
r.set("b", 2)
r.set("c", 3)
keys = list(r.scan_iter(target_nodes="primaries"))
assert set(keys) == {b"a", b"b", b"c"}
keys = list(r.scan_iter(match="a", target_nodes="primaries"))
assert set(keys) == {b"a"}
# Expected results: every key written below, and the subset whose name
# starts with "1".
keys_all = []
keys_1 = []
# Write keys "0".."99" and record their UTF-8 encoded names.
for i in range(100):
s = str(i)
r.set(s, 1)
keys_all.append(s.encode("utf-8"))
if s.startswith("1"):
keys_1.append(s.encode("utf-8"))
# Sort the expectations so they can be compared with sorted() output.
keys_all.sort()
keys_1.sort()

for target_nodes in ["primaries", "replicas"]:
# sorted() consumes the iterator returned by scan_iter in full.
keys = r.scan_iter(target_nodes=target_nodes)
assert sorted(keys) == keys_all

keys = r.scan_iter(match="1*", target_nodes=target_nodes)
assert sorted(keys) == keys_1

def test_cluster_randomkey(self, r):
node = r.get_node_from_key("{foo}")
Expand Down