Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,7 @@ def lock(
name,
timeout=None,
sleep=0.1,
blocking=True,
blocking_timeout=None,
lock_class=None,
thread_local=True,
Expand All @@ -1104,6 +1105,12 @@ def lock(
when the lock is in blocking mode and another client is currently
holding the lock.

``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or to fail immediately, causing ``acquire``
to return False and the lock not being acquired. Defaults to True.
Note this value can be overridden by passing a ``blocking``
argument to ``acquire``.

``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
Expand Down Expand Up @@ -1146,6 +1153,7 @@ def lock(
name,
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
Expand Down
8 changes: 8 additions & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,7 @@ def lock(
name,
timeout=None,
sleep=0.1,
blocking=True,
blocking_timeout=None,
lock_class=None,
thread_local=True,
Expand All @@ -816,6 +817,12 @@ def lock(
when the lock is in blocking mode and another client is currently
holding the lock.

``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or to fail immediately, causing ``acquire``
to return False and the lock not being acquired. Defaults to True.
Note this value can be overridden by passing a ``blocking``
argument to ``acquire``.

``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
Expand Down Expand Up @@ -858,6 +865,7 @@ def lock(
name,
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
Expand Down
68 changes: 43 additions & 25 deletions redis/lock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import threading
import time as mod_time
import uuid
from types import SimpleNamespace
from types import SimpleNamespace, TracebackType
from typing import Optional, Type

from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number


class Lock:
Expand Down Expand Up @@ -74,12 +76,13 @@ class Lock:
def __init__(
self,
redis,
name,
timeout=None,
sleep=0.1,
blocking=True,
blocking_timeout=None,
thread_local=True,
name: str,
*,
timeout: Optional[Number] = None,
sleep: Number = 0.1,
blocking: bool = True,
blocking_timeout: Optional[Number] = None,
thread_local: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
Expand Down Expand Up @@ -142,7 +145,7 @@ def __init__(
self.local.token = None
self.register_scripts()

def register_scripts(self):
def register_scripts(self) -> None:
cls = self.__class__
client = self.redis
if cls.lua_release is None:
Expand All @@ -152,15 +155,27 @@ def register_scripts(self):
if cls.lua_reacquire is None:
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)

def __enter__(self):
def __enter__(self) -> "Lock":
if self.acquire():
return self
raise LockError("Unable to acquire lock within the time specified")

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.release()

def acquire(self, blocking=None, blocking_timeout=None, token=None):
def acquire(
self,
*,
sleep: Optional[Number] = None,
blocking: Optional[bool] = None,
blocking_timeout: Optional[Number] = None,
token: Optional[str] = None,
):
"""
Use Redis to hold a shared, distributed lock named ``name``.
Returns True once the lock is acquired.
Expand All @@ -176,7 +191,8 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None):
object with the default encoding. If a token isn't specified, a UUID
will be generated.
"""
sleep = self.sleep
if sleep is None:
sleep = self.sleep
if token is None:
token = uuid.uuid1().hex.encode()
else:
Expand All @@ -200,7 +216,7 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None):
return False
mod_time.sleep(sleep)

def do_acquire(self, token):
def do_acquire(self, token: str) -> bool:
if self.timeout:
# convert to milliseconds
timeout = int(self.timeout * 1000)
Expand All @@ -210,13 +226,13 @@ def do_acquire(self, token):
return True
return False

def locked(self):
def locked(self) -> bool:
"""
Returns True if this key is locked by any process, otherwise False.
"""
return self.redis.get(self.name) is not None

def owned(self):
def owned(self) -> bool:
"""
Returns True if this key is locked by this lock, otherwise False.
"""
Expand All @@ -228,21 +244,23 @@ def owned(self):
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token

def release(self):
"Releases the already acquired lock"
def release(self) -> None:
"""
Releases the already acquired lock
"""
expected_token = self.local.token
if expected_token is None:
raise LockError("Cannot release an unlocked lock")
self.local.token = None
self.do_release(expected_token)

def do_release(self, expected_token):
def do_release(self, expected_token: str) -> None:
if not bool(
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
):
raise LockNotOwnedError("Cannot release a lock" " that's no longer owned")

def extend(self, additional_time, replace_ttl=False):
def extend(self, additional_time: int, replace_ttl: bool = False) -> bool:
"""
Adds more time to an already acquired lock.

Expand All @@ -259,19 +277,19 @@ def extend(self, additional_time, replace_ttl=False):
raise LockError("Cannot extend a lock with no timeout")
return self.do_extend(additional_time, replace_ttl)

def do_extend(self, additional_time, replace_ttl):
def do_extend(self, additional_time: int, replace_ttl: bool) -> bool:
additional_time = int(additional_time * 1000)
if not bool(
self.lua_extend(
keys=[self.name],
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
args=[self.local.token, additional_time, "1" if replace_ttl else "0"],
client=self.redis,
)
):
raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned")
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
return True

def reacquire(self):
def reacquire(self) -> bool:
"""
Resets a TTL of an already acquired lock back to a timeout value.
"""
Expand All @@ -281,12 +299,12 @@ def reacquire(self):
raise LockError("Cannot reacquire a lock with no timeout")
return self.do_reacquire()

def do_reacquire(self):
def do_reacquire(self) -> bool:
timeout = int(self.timeout * 1000)
if not bool(
self.lua_reacquire(
keys=[self.name], args=[self.local.token, timeout], client=self.redis
)
):
raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned")
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
return True
1 change: 1 addition & 0 deletions redis/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from redis.connection import ConnectionPool


Number = Union[int, float]
EncodedT = Union[bytes, memoryview]
DecodedT = Union[str, int, float]
EncodableT = Union[EncodedT, DecodedT]
Expand Down
20 changes: 20 additions & 0 deletions tests/test_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ def test_context_manager(self, r):
assert r.get("foo") == lock.local.token
assert r.get("foo") is None

def test_context_manager_blocking_timeout(self, r):
with self.get_lock(r, "foo", blocking=False):
bt = 0.4
sleep = 0.05
lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
start = time.monotonic()
assert not lock2.acquire()
# The elapsed duration should be less than the total blocking_timeout
assert bt > (time.monotonic() - start) > bt - sleep

def test_context_manager_raises_when_locked_not_acquired(self, r):
r.set("foo", "bar")
with pytest.raises(LockError):
Expand Down Expand Up @@ -221,6 +231,16 @@ def test_reacquiring_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockNotOwnedError):
lock.reacquire()

def test_context_manager_reacquiring_lock_with_no_timeout_raises_error(self, r):
with self.get_lock(r, "foo", timeout=None, blocking=False) as lock:
with pytest.raises(LockError):
lock.reacquire()

def test_context_manager_reacquiring_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockNotOwnedError):
with self.get_lock(r, "foo", timeout=10, blocking=False):
r.set("foo", "a")


class TestLockClassSelection:
def test_lock_class_argument(self, r):
Expand Down