Skip to content

Commit 24714ae

Browse files
committed
Added async API
1 parent 87a1ffa commit 24714ae

File tree

7 files changed

+65
-21
lines changed

7 files changed

+65
-21
lines changed

redis/asyncio/client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
list_or_args,
5454
)
5555
from redis.credentials import CredentialProvider
56+
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType
5657
from redis.exceptions import (
5758
ConnectionError,
5859
ExecAbortError,
@@ -233,6 +234,7 @@ def __init__(
233234
redis_connect_func=None,
234235
credential_provider: Optional[CredentialProvider] = None,
235236
protocol: Optional[int] = 2,
237+
event_dispatcher: Optional[EventDispatcher] = EventDispatcher(),
236238
):
237239
"""
238240
Initialize a new Redis client.
@@ -320,9 +322,19 @@ def __init__(
320322
# This arg only used if no pool is passed in
321323
self.auto_close_connection_pool = auto_close_connection_pool
322324
connection_pool = ConnectionPool(**kwargs)
325+
event_dispatcher.dispatch(AfterPooledConnectionsInstantiationEvent(
326+
[connection_pool],
327+
ClientType.ASYNC,
328+
credential_provider
329+
))
323330
else:
324331
# If a pool is passed in, do not close it
325332
self.auto_close_connection_pool = False
333+
event_dispatcher.dispatch(AfterPooledConnectionsInstantiationEvent(
334+
[connection_pool],
335+
ClientType.ASYNC,
336+
credential_provider
337+
))
326338

327339
self.connection_pool = connection_pool
328340
self.single_connection_client = single_connection_client

redis/asyncio/connection.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -504,10 +504,6 @@ async def send_packed_command(
504504

505505
async def send_command(self, *args: Any, **kwargs: Any) -> None:
506506
"""Pack and send a command to the Redis server"""
507-
if isinstance(self.credential_provider, StreamingCredentialProvider):
508-
await self._event_dispatcher.dispatch_async(
509-
AsyncBeforeCommandExecutionEvent(args, self._init_auth_args, self, self.credential_provider)
510-
)
511507
await self.send_packed_command(
512508
self.pack_command(*args), check_health=kwargs.get("check_health", True)
513509
)
@@ -1045,6 +1041,7 @@ def __init__(
10451041
self._available_connections: List[AbstractConnection] = []
10461042
self._in_use_connections: Set[AbstractConnection] = set()
10471043
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
1044+
self._lock = asyncio.Lock()
10481045

10491046
def __repr__(self):
10501047
return (
@@ -1064,13 +1061,14 @@ def can_get_connection(self) -> bool:
10641061
)
10651062

10661063
async def get_connection(self, command_name, *keys, **options):
1067-
"""Get a connected connection from the pool"""
1068-
connection = self.get_available_connection()
1069-
try:
1070-
await self.ensure_connection(connection)
1071-
except BaseException:
1072-
await self.release(connection)
1073-
raise
1064+
async with self._lock:
1065+
"""Get a connected connection from the pool"""
1066+
connection = self.get_available_connection()
1067+
try:
1068+
await self.ensure_connection(connection)
1069+
except BaseException:
1070+
await self.release(connection)
1071+
raise
10741072

10751073
return connection
10761074

@@ -1153,6 +1151,14 @@ def set_retry(self, retry: "Retry") -> None:
11531151
for conn in self._in_use_connections:
11541152
conn.retry = retry
11551153

1154+
async def re_auth_callback(self, token):
1155+
async with self._lock:
1156+
for conn in self._available_connections:
1157+
await conn.send_command(
1158+
'AUTH', token.try_get('oid'), token.get_value()
1159+
)
1160+
await conn.read_response()
1161+
11561162

11571163
class BlockingConnectionPool(ConnectionPool):
11581164
"""

redis/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
UnixDomainSocketConnection,
2828
)
2929
from redis.credentials import CredentialProvider
30-
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent
30+
from redis.event import EventDispatcher, AfterPooledConnectionsInstantiationEvent, ClientType
3131
from redis.exceptions import (
3232
ConnectionError,
3333
ExecAbortError,

redis/cluster.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from redis.commands.helpers import list_or_args
1616
from redis.connection import ConnectionPool, DefaultParser, parse_url
1717
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
18-
from redis.event import EventDispatcher, EventDispatcherInterface, AfterPooledConnectionsInstantiationEvent
18+
from redis.event import EventDispatcher, EventDispatcherInterface, AfterPooledConnectionsInstantiationEvent, ClientType
1919
from redis.exceptions import (
2020
AskError,
2121
AuthenticationError,
@@ -1494,7 +1494,11 @@ def create_redis_connections(self, nodes):
14941494
connection_pools.append(node.redis_connection.connection_pool)
14951495

14961496
self._event_dispatcher.dispatch(
1497-
AfterPooledConnectionsInstantiationEvent(connection_pools, self._credential_provider)
1497+
AfterPooledConnectionsInstantiationEvent(
1498+
connection_pools,
1499+
ClientType.SYNC,
1500+
self._credential_provider
1501+
)
14981502
)
14991503

15001504
def create_redis_node(self, host, port, **kwargs):

redis/connection.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -558,10 +558,6 @@ def send_packed_command(self, command, check_health=True):
558558

559559
def send_command(self, *args, **kwargs):
560560
"""Pack and send a command to the Redis server"""
561-
if isinstance(self.credential_provider, StreamingCredentialProvider):
562-
self._event_dispatcher.dispatch(
563-
BeforeCommandExecutionEvent(args, self._init_auth_args, self, self.credential_provider)
564-
)
565561
self.send_packed_command(
566562
self._command_packer.pack(*args),
567563
check_health=kwargs.get("check_health", True),

redis/credentials.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import logging
12
from abc import ABC, abstractmethod
23
from typing import Optional, Tuple, Union, Callable, Any
34

5+
logger = logging.getLogger(__name__)
6+
47

58
class CredentialProvider:
69
"""
@@ -11,7 +14,8 @@ def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
1114
raise NotImplementedError("get_credentials must be implemented")
1215

1316
async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
14-
raise NotImplementedError("get_credentials_async must be implemented")
17+
logger.warning("This method is added for backward compatability. Please override it in your implementation.")
18+
return self.get_credentials()
1519

1620

1721
class StreamingCredentialProvider(CredentialProvider, ABC):

redis/event.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,33 @@ class AsyncBeforeCommandExecutionEvent(BeforeCommandExecutionEvent):
9898
pass
9999

100100

101+
class ClientType(Enum):
102+
SYNC = "sync",
103+
ASYNC = "async",
104+
105+
101106
class AfterPooledConnectionsInstantiationEvent:
102107
"""
103108
Event that will be fired after pooled connection instances was created.
104109
"""
105110
def __init__(
106111
self,
107112
connection_pools: List,
113+
client_type: ClientType,
108114
credential_provider: Optional[CredentialProvider] = None,
109115
):
110116
self._connection_pools = connection_pools
117+
self._client_type = client_type
111118
self._credential_provider = credential_provider
112119

113120
@property
114121
def connection_pools(self):
115122
return self._connection_pools
116123

124+
@property
125+
def client_type(self) -> ClientType:
126+
return self._client_type
127+
117128
@property
118129
def credential_provider(self) -> Union[CredentialProvider, None]:
119130
return self._credential_provider
@@ -127,6 +138,9 @@ def __init__(self):
127138
self._current_cred = None
128139

129140
def listen(self, event: BeforeCommandExecutionEvent):
141+
if event.command[0] == 'AUTH':
142+
return
143+
130144
if self._current_cred is None:
131145
self._current_cred = event.initial_cred
132146

@@ -168,8 +182,16 @@ def __init__(self):
168182
def listen(self, event: AfterPooledConnectionsInstantiationEvent):
169183
if isinstance(event.credential_provider, StreamingCredentialProvider):
170184
self._event = event
171-
event.credential_provider.on_next(self._re_auth)
185+
186+
if event.client_type == ClientType.SYNC:
187+
event.credential_provider.on_next(self._re_auth)
188+
else:
189+
event.credential_provider.on_next(self._re_auth_async)
172190

173191
def _re_auth(self, token):
174192
for pool in self._event.connection_pools:
175-
pool.re_auth_callback(token)
193+
pool.re_auth_callback(token)
194+
195+
async def _re_auth_async(self, token):
196+
for pool in self._event.connection_pools:
197+
await pool.re_auth_callback(token)

0 commit comments

Comments
 (0)