Skip to content

Commit

Permalink
Enable AsyncIO cluster mode lock (#2446)
Browse files Browse the repository at this point in the history
Co-authored-by: Chayim <chayim@users.noreply.github.com>
  • Loading branch information
KMilhan and chayim authored Nov 9, 2022
1 parent 1cdba63 commit 772079f
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
* ClusterPipeline Doesn't Handle ConnectionError for Dead Hosts (#2225)
* Remove compatibility code for old versions of Hiredis, drop Packaging dependency
* The `deprecated` library is no longer a dependency
* Enable Lock for asyncio cluster mode

* 4.1.3 (Feb 8, 2022)
* Fix flushdb and flushall (#1926)
Expand Down
67 changes: 67 additions & 0 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SSLConnection,
parse_url,
)
from redis.asyncio.lock import Lock
from redis.asyncio.parser import CommandsParser
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
from redis.cluster import (
Expand Down Expand Up @@ -764,6 +765,72 @@ def pipeline(

return ClusterPipeline(self)

def lock(
self,
name: KeyT,
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking_timeout: Optional[float] = None,
lock_class: Optional[Type[Lock]] = None,
thread_local: bool = True,
) -> Lock:
"""
Return a new Lock object using key ``name`` that mimics
the behavior of threading.Lock.
If specified, ``timeout`` indicates a maximum life for the lock.
By default, it will remain locked until release() is called.
``sleep`` indicates the amount of time to sleep per loop iteration
when the lock is in blocking mode and another client is currently
holding the lock.
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
float or integer, both representing the number of seconds to wait.
``lock_class`` forces the specified lock implementation. Note that as
of redis-py 3.0, the only lock class we implement is ``Lock`` (which is
a Lua-based lock). So, it's unlikely you'll need this parameter, unless
you have created your own custom lock class.
``thread_local`` indicates whether the lock token is placed in
thread-local storage. By default, the token is placed in thread local
storage so that a thread only sees its token, not a token set by
another thread. Consider the following timeline:
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
thread-1 sets the token to "abc"
time: 1, thread-2 blocks trying to acquire `my-lock` using the
Lock instance.
time: 5, thread-1 has not yet completed. redis expires the lock
key.
time: 5, thread-2 acquired `my-lock` now that it's available.
thread-2 sets the token to "xyz"
time: 6, thread-1 finishes its work and calls release(). if the
token is *not* stored in thread local storage, then
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
local storage isn't disabled in this case, the worker thread won't see
the token set by the thread that acquired the lock. Our assumption
is that these cases aren't common and as such default to using
thread local storage."""
if lock_class is None:
lock_class = Lock
return lock_class(
self,
name,
timeout=timeout,
sleep=sleep,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)


class ClusterNode:
"""
Expand Down
16 changes: 12 additions & 4 deletions redis/asyncio/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from redis.exceptions import LockError, LockNotOwnedError

if TYPE_CHECKING:
from redis.asyncio import Redis
from redis.asyncio import Redis, RedisCluster


class Lock:
Expand Down Expand Up @@ -77,7 +77,7 @@ class Lock:

def __init__(
self,
redis: "Redis",
redis: Union["Redis", "RedisCluster"],
name: Union[str, bytes, memoryview],
timeout: Optional[float] = None,
sleep: float = 0.1,
Expand Down Expand Up @@ -189,7 +189,11 @@ async def acquire(
if token is None:
token = uuid.uuid1().hex.encode()
else:
encoder = self.redis.connection_pool.get_encoder()
try:
encoder = self.redis.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = self.redis.get_encoder()
token = encoder.encode(token)
if blocking is None:
blocking = self.blocking
Expand Down Expand Up @@ -233,7 +237,11 @@ async def owned(self) -> bool:
# need to always compare bytes to bytes
# TODO: this can be simplified when the context manager is finished
if stored_token and not isinstance(stored_token, bytes):
encoder = self.redis.connection_pool.get_encoder()
try:
encoder = self.redis.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = self.redis.get_encoder()
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token

Expand Down
12 changes: 10 additions & 2 deletions redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4930,7 +4930,11 @@ def __init__(self, registered_client: "Redis", script: ScriptTextT):
if isinstance(script, str):
# We need the encoding from the client in order to generate an
# accurate byte representation of the script
encoder = registered_client.connection_pool.get_encoder()
try:
encoder = registered_client.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = registered_client.get_encoder()
script = encoder.encode(script)
self.sha = hashlib.sha1(script).hexdigest()

Expand Down Expand Up @@ -4975,7 +4979,11 @@ def __init__(self, registered_client: "AsyncRedis", script: ScriptTextT):
if isinstance(script, str):
# We need the encoding from the client in order to generate an
# accurate byte representation of the script
encoder = registered_client.connection_pool.get_encoder()
try:
encoder = registered_client.connection_pool.get_encoder()
except AttributeError:
# Cluster
encoder = registered_client.get_encoder()
script = encoder.encode(script)
self.sha = hashlib.sha1(script).hexdigest()

Expand Down
1 change: 0 additions & 1 deletion tests/test_asyncio/test_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from redis.exceptions import LockError, LockNotOwnedError


@pytest.mark.onlynoncluster
class TestLock:
@pytest_asyncio.fixture()
async def r_decoded(self, create_redis):
Expand Down

0 comments on commit 772079f

Please sign in to comment.