From 5c99e27459a047cd1334e1b87fb0623ac2c881db Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Tue, 3 May 2022 14:04:14 +0300 Subject: [PATCH] ACL SETUSER - add selectors and key based permissions (#2161) * acl setuser * async tests Co-authored-by: Chayim --- redis/client.py | 13 ++++++++++ redis/commands/core.py | 32 +++++++++++++++++++++--- tests/test_asyncio/test_commands.py | 2 ++ tests/test_commands.py | 38 ++++++++++++++++++++++++++--- 4 files changed, 77 insertions(+), 8 deletions(-) diff --git a/redis/client.py b/redis/client.py index e44f5abeb2..87c79913d3 100755 --- a/redis/client.py +++ b/redis/client.py @@ -580,6 +580,19 @@ def parse_acl_getuser(response, **options): data["flags"] = list(map(str_if_bytes, data["flags"])) data["passwords"] = list(map(str_if_bytes, data["passwords"])) data["commands"] = str_if_bytes(data["commands"]) + if isinstance(data["keys"], str) or isinstance(data["keys"], bytes): + data["keys"] = list(str_if_bytes(data["keys"]).split(" ")) + if data["keys"] == [""]: + data["keys"] = [] + if "channels" in data: + if isinstance(data["channels"], str) or isinstance(data["channels"], bytes): + data["channels"] = list(str_if_bytes(data["channels"]).split(" ")) + if data["channels"] == [""]: + data["channels"] = [] + if "selectors" in data: + data["selectors"] = [ + list(map(str_if_bytes, selector)) for selector in data["selectors"] + ] # split 'commands' into separate 'categories' and 'commands' lists commands, categories = [], [] diff --git a/redis/commands/core.py b/redis/commands/core.py index 8bbcda3a69..6526ef167a 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -186,9 +186,11 @@ def acl_setuser( nopass: bool = False, passwords: Union[str, Iterable[str], None] = None, hashed_passwords: Union[str, Iterable[str], None] = None, - categories: Union[Iterable[str], None] = None, - commands: Union[Iterable[str], None] = None, - keys: Union[Iterable[KeyT], None] = None, + categories: Optional[Iterable[str]] = None, + commands: Optional[Iterable[str]] = None, + keys: Optional[Iterable[KeyT]] = None, + channels: Optional[Iterable[ChannelT]] = None, + selectors: Optional[Iterable[Tuple[str, KeyT]]] = None, reset: bool = False, reset_keys: bool = False, reset_passwords: bool = False, @@ -342,7 +344,29 @@ def acl_setuser( if keys: for key in keys: key = encoder.encode(key) - pieces.append(b"~%s" % key) + if not key.startswith(b"%") and not key.startswith(b"~"): + key = b"~%s" % key + pieces.append(key) + + if channels: + for channel in channels: + channel = encoder.encode(channel) + pieces.append(b"&%s" % channel) + + if selectors: + for cmd, key in selectors: + cmd = encoder.encode(cmd) + if not cmd.startswith(b"+") and not cmd.startswith(b"-"): + raise DataError( + f'Command "{encoder.decode(cmd, force=True)}" ' + 'must be prefixed with "+" or "-"' + ) + + key = encoder.encode(key) + if not key.startswith(b"%") and not key.startswith(b"~"): + key = b"~%s" % key + + pieces.append(b"(%s %s)" % (cmd, key)) return self.execute_command("ACL SETUSER", *pieces, **kwargs) diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 78220406eb..dee8755a17 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -109,6 +109,7 @@ async def test_acl_genpass(self, r: redis.Redis): assert isinstance(password, str) @skip_if_server_version_lt(REDIS_6_VERSION) + @skip_if_server_version_gte("7.0.0") async def test_acl_getuser_setuser(self, r: redis.Redis, request, event_loop): username = "redis-py-user" @@ -224,6 +225,7 @@ def teardown(): assert len((await r.acl_getuser(username))["passwords"]) == 1 @skip_if_server_version_lt(REDIS_6_VERSION) + @skip_if_server_version_gte("7.0.0") async def test_acl_list(self, r: redis.Redis, request, event_loop): username = "redis-py-user" diff --git a/tests/test_commands.py b/tests/test_commands.py index 59754123ac..b7287b4fea 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -120,8 +120,14 @@ def test_acl_cat_with_category(self, r): @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() - def test_acl_dryrun(self, r): + def test_acl_dryrun(self, r, request): username = "redis-py-user" + + def teardown(): + r.acl_deluser(username) + + request.addfinalizer(teardown) + r.acl_setuser( username, keys=["*"], @@ -171,7 +177,7 @@ def test_acl_genpass(self, r): r.acl_genpass(555) assert isinstance(password, str) - @skip_if_server_version_lt("6.0.0") + @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() def test_acl_getuser_setuser(self, r, request): username = "redis-py-user" @@ -217,7 +223,7 @@ def teardown(): assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True assert "on" in acl["flags"] - assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 # test reset=False keeps existing ACL and applies new ACL on top @@ -243,7 +249,7 @@ def teardown(): assert set(acl["commands"]) == {"+get", "+mget"} assert acl["enabled"] is True assert "on" in acl["flags"] - assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 # test removal of passwords @@ -278,6 +284,30 @@ def teardown(): ) assert len(r.acl_getuser(username)["passwords"]) == 1 + # test selectors + assert r.acl_setuser( + username, + enabled=True, + reset=True, + passwords=["+pass1", "+pass2"], + categories=["+set", "+@hash", "-geo"], + commands=["+get", "+mget", "-hset"], + keys=["cache:*", "objects:*"], + channels=["message:*"], + selectors=[("+set", "%W~app*")], + ) + acl = r.acl_getuser(username) + assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["commands"]) == {"+get", "+mget", "-hset"} + assert acl["enabled"] is True + assert "on" in acl["flags"] + assert set(acl["keys"]) == {"~cache:*", "~objects:*"} + assert len(acl["passwords"]) == 2 + assert set(acl["channels"]) == {"&message:*"} + assert acl["selectors"] == [ + ["commands", "-@all +set", "keys", "%W~app*", "channels", ""] + ] + @skip_if_server_version_lt("6.0.0") def test_acl_help(self, r): res = r.acl_help()