Skip to content

Commit

Permalink
[CLUSTER] Fix scan command cursors & Fix scan_iter (#2054)
Browse files Browse the repository at this point in the history
* cluster/scan: fix return cursor & change default node to primaries

* cluster/scan_iter: fix iteration

Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
  • Loading branch information
utkarshgupta137 and dvora-h authored Mar 23, 2022
1 parent 032fd22 commit 827dcde
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 21 deletions.
4 changes: 3 additions & 1 deletion CHANGES
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@

* Add `items` parameter to `hset` signature
* Create codeql-analysis.yml (#1988). Thanks @chayim
* Create codeql-analysis.yml (#1988). Thanks @chayim
* Add limited support for Lua scripting with RedisCluster
* Implement `.lock()` method on RedisCluster
* Fix cursor returned by SCAN for RedisCluster & change default target to PRIMARIES
* Fix scan_iter for RedisCluster
* Remove verbose logging when initializing ClusterPubSub, ClusterPipeline or RedisCluster

* 4.1.3 (Feb 8, 2022)
Expand Down
16 changes: 10 additions & 6 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from collections import OrderedDict

from redis.client import CaseInsensitiveDict, PubSub, Redis
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
from redis.commands import CommandsParser, RedisClusterCommands
from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
Expand Down Expand Up @@ -51,10 +51,14 @@ def get_connection(redis_node, *args, **options):


def parse_scan_result(command, res, **options):
keys_list = []
for primary_res in res.values():
keys_list += primary_res[1]
return 0, keys_list
cursors = {}
ret = []
for node_name, response in res.items():
cursor, r = parse_scan(response, **options)
cursors[node_name] = cursor
ret += r

return cursors, ret


def parse_pubsub_numsub(command, res, **options):
Expand Down Expand Up @@ -244,7 +248,6 @@ class RedisCluster(RedisClusterCommands):
"INFO",
"SHUTDOWN",
"KEYS",
"SCAN",
"DBSIZE",
"BGSAVE",
"SLOWLOG GET",
Expand Down Expand Up @@ -298,6 +301,7 @@ class RedisCluster(RedisClusterCommands):
"FUNCTION LIST",
"FUNCTION LOAD",
"FUNCTION RESTORE",
"SCAN",
"SCRIPT EXISTS",
"SCRIPT FLUSH",
"SCRIPT LOAD",
Expand Down
38 changes: 38 additions & 0 deletions redis/commands/cluster.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Iterator, Union

from redis.crc import key_slot
from redis.exceptions import RedisClusterException, RedisError
from redis.typing import PatternT

from .core import (
ACLCommands,
Expand Down Expand Up @@ -206,6 +209,41 @@ def stralgo(
**kwargs,
)

def scan_iter(
self,
match: Union[PatternT, None] = None,
count: Union[int, None] = None,
_type: Union[str, None] = None,
**kwargs,
) -> Iterator:
# Do the first query with cursor=0 for all nodes
cursors, data = self.scan(match=match, count=count, _type=_type, **kwargs)
yield from data

cursors = {name: cursor for name, cursor in cursors.items() if cursor != 0}
if cursors:
# Get nodes by name
nodes = {name: self.get_node(node_name=name) for name in cursors.keys()}

# Iterate over each node till its cursor is 0
kwargs.pop("target_nodes", None)
while cursors:
for name, cursor in cursors.items():
cur, data = self.scan(
cursor=cursor,
match=match,
count=count,
_type=_type,
target_nodes=nodes[name],
**kwargs,
)
yield from data
cursors[name] = cur[name]

cursors = {
name: cursor for name, cursor in cursors.items() if cursor != 0
}


class RedisClusterCommands(
ClusterMultiKeyCommands,
Expand Down
59 changes: 45 additions & 14 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,29 +1773,60 @@ def test_cluster_scan(self, r):
r.set("a", 1)
r.set("b", 2)
r.set("c", 3)
cursor, keys = r.scan(target_nodes="primaries")
assert cursor == 0
assert set(keys) == {b"a", b"b", b"c"}
_, keys = r.scan(match="a", target_nodes="primaries")
assert set(keys) == {b"a"}

for target_nodes, nodes in zip(
["primaries", "replicas"], [r.get_primaries(), r.get_replicas()]
):
cursors, keys = r.scan(target_nodes=target_nodes)
assert sorted(keys) == [b"a", b"b", b"c"]
assert sorted(cursors.keys()) == sorted(node.name for node in nodes)
assert all(cursor == 0 for cursor in cursors.values())

cursors, keys = r.scan(match="a*", target_nodes=target_nodes)
assert sorted(keys) == [b"a"]
assert sorted(cursors.keys()) == sorted(node.name for node in nodes)
assert all(cursor == 0 for cursor in cursors.values())

@skip_if_server_version_lt("6.0.0")
def test_cluster_scan_type(self, r):
r.sadd("a-set", 1)
r.sadd("b-set", 1)
r.sadd("c-set", 1)
r.hset("a-hash", "foo", 2)
r.lpush("a-list", "aux", 3)
_, keys = r.scan(match="a*", _type="SET", target_nodes="primaries")
assert set(keys) == {b"a-set"}

for target_nodes, nodes in zip(
["primaries", "replicas"], [r.get_primaries(), r.get_replicas()]
):
cursors, keys = r.scan(_type="SET", target_nodes=target_nodes)
assert sorted(keys) == [b"a-set", b"b-set", b"c-set"]
assert sorted(cursors.keys()) == sorted(node.name for node in nodes)
assert all(cursor == 0 for cursor in cursors.values())

cursors, keys = r.scan(_type="SET", match="a*", target_nodes=target_nodes)
assert sorted(keys) == [b"a-set"]
assert sorted(cursors.keys()) == sorted(node.name for node in nodes)
assert all(cursor == 0 for cursor in cursors.values())

@skip_if_server_version_lt("2.8.0")
def test_cluster_scan_iter(self, r):
r.set("a", 1)
r.set("b", 2)
r.set("c", 3)
keys = list(r.scan_iter(target_nodes="primaries"))
assert set(keys) == {b"a", b"b", b"c"}
keys = list(r.scan_iter(match="a", target_nodes="primaries"))
assert set(keys) == {b"a"}
keys_all = []
keys_1 = []
for i in range(100):
s = str(i)
r.set(s, 1)
keys_all.append(s.encode("utf-8"))
if s.startswith("1"):
keys_1.append(s.encode("utf-8"))
keys_all.sort()
keys_1.sort()

for target_nodes in ["primaries", "replicas"]:
keys = r.scan_iter(target_nodes=target_nodes)
assert sorted(keys) == keys_all

keys = r.scan_iter(match="1*", target_nodes=target_nodes)
assert sorted(keys) == keys_1

def test_cluster_randomkey(self, r):
node = r.get_node_from_key("{foo}")
Expand Down

0 comments on commit 827dcde

Please sign in to comment.