Skip to content
Prev Previous commit
Next Next commit
Added command executor
  • Loading branch information
vladvildanov committed Sep 2, 2025
commit ae42bea09a097855e05fbf62ca757a85df53af5a
184 changes: 177 additions & 7 deletions redis/asyncio/multidb/command_executor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from abc import abstractmethod
from datetime import datetime
from typing import List, Optional, Callable, Any

from redis.asyncio.client import PubSub, Pipeline
from redis.asyncio.multidb.database import Databases, AsyncDatabase
from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database
from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \
ResubscribeOnActiveDatabaseChanged
from redis.asyncio.multidb.failover import AsyncFailoverStrategy
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.multidb.circuit import State as CBState
from redis.asyncio.retry import Retry
from redis.multidb.command_executor import CommandExecutor
from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent
from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL


class AsyncCommandExecutor(CommandExecutor):
Expand Down Expand Up @@ -34,9 +40,8 @@ def active_database(self) -> Optional[AsyncDatabase]:
"""Returns currently active database."""
pass

@active_database.setter
@abstractmethod
def active_database(self, database: AsyncDatabase) -> None:
async def set_active_database(self, database: AsyncDatabase) -> None:
"""Sets the currently active database."""
pass

Expand Down Expand Up @@ -85,11 +90,176 @@ async def execute_transaction(self, transaction: Callable[[Pipeline], None], *wa
pass

@abstractmethod
def execute_pubsub_method(self, method_name: str, *args, **kwargs):
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
"""Executes a given method on active pub/sub."""
pass

@abstractmethod
def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
"""Executes pub/sub run in a thread."""
pass
pass


class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor):
def __init__(
self,
failure_detectors: List[AsyncFailureDetector],
databases: Databases,
command_retry: Retry,
failover_strategy: AsyncFailoverStrategy,
event_dispatcher: EventDispatcherInterface,
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
):
"""
Initialize the DefaultCommandExecutor instance.

Args:
failure_detectors: List of failure detector instances to monitor database health
databases: Collection of available databases to execute commands on
command_retry: Retry policy for failed command execution
failover_strategy: Strategy for handling database failover
event_dispatcher: Interface for dispatching events
auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database
"""
super().__init__(auto_fallback_interval)

for fd in failure_detectors:
fd.set_command_executor(command_executor=self)

self._databases = databases
self._failure_detectors = failure_detectors
self._command_retry = command_retry
self._failover_strategy = failover_strategy
self._event_dispatcher = event_dispatcher
self._active_database: Optional[Database] = None
self._active_pubsub: Optional[PubSub] = None
self._active_pubsub_kwargs = {}
self._setup_event_dispatcher()
self._schedule_next_fallback()

@property
def databases(self) -> Databases:
return self._databases

@property
def failure_detectors(self) -> List[AsyncFailureDetector]:
return self._failure_detectors

def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None:
self._failure_detectors.append(failure_detector)

@property
def active_database(self) -> Optional[AsyncDatabase]:
return self._active_database

async def set_active_database(self, database: AsyncDatabase) -> None:
old_active = self._active_database
self._active_database = database

if old_active is not None and old_active is not database:
await self._event_dispatcher.dispatch_async(
AsyncActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs)
)

@property
def active_pubsub(self) -> Optional[PubSub]:
return self._active_pubsub

@active_pubsub.setter
def active_pubsub(self, pubsub: PubSub) -> None:
self._active_pubsub = pubsub

@property
def failover_strategy(self) -> AsyncFailoverStrategy:
return self._failover_strategy

@property
def command_retry(self) -> Retry:
return self._command_retry

async def pubsub(self, **kwargs):
async def callback():
if self._active_pubsub is None:
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs
return None

return await self._execute_with_failure_detection(callback)

async def execute_command(self, *args, **options):
async def callback():
return await self._active_database.client.execute_command(*args, **options)

return await self._execute_with_failure_detection(callback, args)

async def execute_pipeline(self, command_stack: tuple):
async def callback():
with self._active_database.client.pipeline() as pipe:
for command, options in command_stack:
await pipe.execute_command(*command, **options)

return await pipe.execute()

return await self._execute_with_failure_detection(callback, command_stack)

async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options):
async def callback():
return await self._active_database.client.transaction(transaction, *watches, **options)

return await self._execute_with_failure_detection(callback)

async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
async def callback():
method = getattr(self.active_pubsub, method_name)
return await method(*args, **kwargs)

return await self._execute_with_failure_detection(callback, *args)

async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
async def callback():
return await self._active_pubsub.run(poll_timeout=sleep_time, **kwargs)

return await self._execute_with_failure_detection(callback)

async def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()):
"""
Execute a commands execution callback with failure detection.
"""
async def wrapper():
# On each retry we need to check active database as it might change.
await self._check_active_database()
return await callback()

return await self._command_retry.call_with_retry(
lambda: wrapper(),
lambda error: self._on_command_fail(error, *cmds),
)

async def _check_active_database(self):
"""
Checks if active a database needs to be updated.
"""
if (
self._active_database is None
or self._active_database.circuit.state != CBState.CLOSED
or (
self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL
and self._next_fallback_attempt <= datetime.now()
)
):
await self.set_active_database(await self._failover_strategy.database())
self._schedule_next_fallback()

async def _on_command_fail(self, error, *args):
await self._event_dispatcher.dispatch_async(AsyncOnCommandsFailEvent(args, error))

def _setup_event_dispatcher(self):
"""
Registers necessary listeners.
"""
failure_listener = RegisterCommandFailure(self._failure_detectors)
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
self._event_dispatcher.register_listeners({
AsyncOnCommandsFailEvent: [failure_listener],
AsyncActiveDatabaseChanged: [resubscribe_listener],
})
2 changes: 0 additions & 2 deletions redis/asyncio/multidb/failover.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

class AsyncFailoverStrategy(ABC):

@property
@abstractmethod
async def database(self) -> AsyncDatabase:
"""Select the database according to the strategy."""
Expand All @@ -33,7 +32,6 @@ def __init__(
self._retry.update_supported_errors([NoValidDatabaseException])
self._databases = WeightedList()

@property
async def database(self) -> AsyncDatabase:
return await self._retry.call_with_retry(
lambda: self._get_active_database(),
Expand Down
7 changes: 5 additions & 2 deletions redis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ async def dispatch_async(self, event: object):
pass

@abstractmethod
def register_listeners(self, mappings: Dict[Type[object], List[EventListenerInterface]]):
def register_listeners(
self,
mappings: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]]
):
"""Register additional listeners."""
pass

Expand Down Expand Up @@ -99,7 +102,7 @@ def dispatch(self, event: object):
listener.listen(event)

async def dispatch_async(self, event: object):
with self._async_lock:
async with self._async_lock:
listeners = self._event_listeners_mapping.get(type(event), [])

for listener in listeners:
Expand Down
29 changes: 27 additions & 2 deletions tests/test_asyncio/test_multidb/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

import pytest

from redis.asyncio.multidb.failover import AsyncFailoverStrategy
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
from redis.asyncio.multidb.healthcheck import HealthCheck
from redis.data_structure import WeightedList
from redis.multidb.circuit import State as CBState
from redis.asyncio import Redis
from redis.asyncio.multidb.circuit import AsyncCircuitBreaker
from redis.asyncio.multidb.database import Database
from redis.asyncio.multidb.database import Database, Databases


@pytest.fixture()
Expand All @@ -16,6 +20,18 @@ def mock_client() -> Redis:
def mock_cb() -> AsyncCircuitBreaker:
return Mock(spec=AsyncCircuitBreaker)

@pytest.fixture()
def mock_fd() -> AsyncFailureDetector:
return Mock(spec=AsyncFailureDetector)

@pytest.fixture()
def mock_fs() -> AsyncFailoverStrategy:
return Mock(spec=AsyncFailoverStrategy)

@pytest.fixture()
def mock_hc() -> HealthCheck:
return Mock(spec=HealthCheck)

@pytest.fixture()
def mock_db(request) -> Database:
db = Mock(spec=Database)
Expand Down Expand Up @@ -56,4 +72,13 @@ def mock_db2(request) -> Database:
mock_cb.state = cb.get("state", CBState.CLOSED)

db.circuit = mock_cb
return db
return db


def create_weighted_list(*databases) -> Databases:
dbs = WeightedList()

for db in databases:
dbs.add(db, db.weight)

return dbs
Loading