Skip to content

Made sync lock consistent and added types to it #2137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,7 @@ def lock(
name,
timeout=None,
sleep=0.1,
blocking=True,
blocking_timeout=None,
lock_class=None,
thread_local=True,
Expand All @@ -1113,6 +1114,12 @@ def lock(
when the lock is in blocking mode and another client is currently
holding the lock.

``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or to fail immediately, causing ``acquire``
to return False and the lock not being acquired. Defaults to True.
Note this value can be overridden by passing a ``blocking``
argument to ``acquire``.

``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
Expand Down Expand Up @@ -1155,6 +1162,7 @@ def lock(
name,
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
Expand Down
8 changes: 8 additions & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,7 @@ def lock(
name,
timeout=None,
sleep=0.1,
blocking=True,
blocking_timeout=None,
lock_class=None,
thread_local=True,
Expand All @@ -781,6 +782,12 @@ def lock(
when the lock is in blocking mode and another client is currently
holding the lock.

``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or to fail immediately, causing ``acquire``
to return False and the lock not being acquired. Defaults to True.
Note this value can be overridden by passing a ``blocking``
argument to ``acquire``.

``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
Expand Down Expand Up @@ -823,6 +830,7 @@ def lock(
name,
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
Expand Down
68 changes: 43 additions & 25 deletions redis/lock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import threading
import time as mod_time
import uuid
from types import SimpleNamespace
from types import SimpleNamespace, TracebackType
from typing import Optional, Type

from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number


class Lock:
Expand Down Expand Up @@ -74,12 +76,13 @@ class Lock:
def __init__(
self,
redis,
name,
timeout=None,
sleep=0.1,
blocking=True,
blocking_timeout=None,
thread_local=True,
name: str,
*,
timeout: Optional[Number] = None,
sleep: Number = 0.1,
blocking: bool = True,
blocking_timeout: Optional[Number] = None,
thread_local: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
Expand Down Expand Up @@ -142,7 +145,7 @@ def __init__(
self.local.token = None
self.register_scripts()

def register_scripts(self):
def register_scripts(self) -> None:
cls = self.__class__
client = self.redis
if cls.lua_release is None:
Expand All @@ -152,15 +155,27 @@ def register_scripts(self):
if cls.lua_reacquire is None:
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)

def __enter__(self):
def __enter__(self) -> "Lock":
if self.acquire():
return self
raise LockError("Unable to acquire lock within the time specified")

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.release()

def acquire(self, blocking=None, blocking_timeout=None, token=None):
def acquire(
self,
*,
sleep: Optional[Number] = None,
blocking: Optional[bool] = None,
blocking_timeout: Optional[Number] = None,
token: Optional[str] = None,
):
"""
Use Redis to hold a shared, distributed lock named ``name``.
Returns True once the lock is acquired.
Expand All @@ -176,7 +191,8 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None):
object with the default encoding. If a token isn't specified, a UUID
will be generated.
"""
sleep = self.sleep
if sleep is None:
sleep = self.sleep
if token is None:
token = uuid.uuid1().hex.encode()
else:
Expand All @@ -200,7 +216,7 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None):
return False
mod_time.sleep(sleep)

def do_acquire(self, token):
def do_acquire(self, token: str) -> bool:
if self.timeout:
# convert to milliseconds
timeout = int(self.timeout * 1000)
Expand All @@ -210,13 +226,13 @@ def do_acquire(self, token):
return True
return False

def locked(self):
def locked(self) -> bool:
"""
Returns True if this key is locked by any process, otherwise False.
"""
return self.redis.get(self.name) is not None

def owned(self):
def owned(self) -> bool:
"""
Returns True if this key is locked by this lock, otherwise False.
"""
Expand All @@ -228,21 +244,23 @@ def owned(self):
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token

def release(self):
"Releases the already acquired lock"
def release(self) -> None:
"""
Releases the already acquired lock
"""
expected_token = self.local.token
if expected_token is None:
raise LockError("Cannot release an unlocked lock")
self.local.token = None
self.do_release(expected_token)

def do_release(self, expected_token):
def do_release(self, expected_token: str) -> None:
if not bool(
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
):
raise LockNotOwnedError("Cannot release a lock" " that's no longer owned")

def extend(self, additional_time, replace_ttl=False):
def extend(self, additional_time: int, replace_ttl: bool = False) -> bool:
"""
Adds more time to an already acquired lock.

Expand All @@ -259,19 +277,19 @@ def extend(self, additional_time, replace_ttl=False):
raise LockError("Cannot extend a lock with no timeout")
return self.do_extend(additional_time, replace_ttl)

def do_extend(self, additional_time, replace_ttl):
def do_extend(self, additional_time: int, replace_ttl: bool) -> bool:
additional_time = int(additional_time * 1000)
if not bool(
self.lua_extend(
keys=[self.name],
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
args=[self.local.token, additional_time, "1" if replace_ttl else "0"],
client=self.redis,
)
):
raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned")
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
return True

def reacquire(self):
def reacquire(self) -> bool:
"""
Resets a TTL of an already acquired lock back to a timeout value.
"""
Expand All @@ -281,12 +299,12 @@ def reacquire(self):
raise LockError("Cannot reacquire a lock with no timeout")
return self.do_reacquire()

def do_reacquire(self):
def do_reacquire(self) -> bool:
timeout = int(self.timeout * 1000)
if not bool(
self.lua_reacquire(
keys=[self.name], args=[self.local.token, timeout], client=self.redis
)
):
raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned")
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
return True
1 change: 1 addition & 0 deletions redis/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from redis.connection import ConnectionPool, Encoder


Number = Union[int, float]
EncodedT = Union[bytes, memoryview]
DecodedT = Union[str, int, float]
EncodableT = Union[EncodedT, DecodedT]
Expand Down
20 changes: 20 additions & 0 deletions tests/test_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ def test_context_manager(self, r):
assert r.get("foo") == lock.local.token
assert r.get("foo") is None

def test_context_manager_blocking_timeout(self, r):
with self.get_lock(r, "foo", blocking=False):
bt = 0.4
sleep = 0.05
lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
start = time.monotonic()
assert not lock2.acquire()
# The elapsed duration should be less than the total blocking_timeout
assert bt > (time.monotonic() - start) > bt - sleep

def test_context_manager_raises_when_locked_not_acquired(self, r):
r.set("foo", "bar")
with pytest.raises(LockError):
Expand Down Expand Up @@ -221,6 +231,16 @@ def test_reacquiring_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockNotOwnedError):
lock.reacquire()

def test_context_manager_reacquiring_lock_with_no_timeout_raises_error(self, r):
with self.get_lock(r, "foo", timeout=None, blocking=False) as lock:
with pytest.raises(LockError):
lock.reacquire()

def test_context_manager_reacquiring_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockNotOwnedError):
with self.get_lock(r, "foo", timeout=10, blocking=False):
r.set("foo", "a")


class TestLockClassSelection:
def test_lock_class_argument(self, r):
Expand Down