From 4b092d8ffb14e71a28b8103a965b6ea5040ddc15 Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 16 Dec 2024 17:41:54 -0800 Subject: [PATCH 1/2] # This is a combination of 2 commits. # This is the 1st commit message: add tests # This is the commit message #2: maybe someday --- .../prefect-redis/prefect_redis/client.py | 32 +++++--- .../prefect-redis/tests/test_client.py | 81 +++++++++++++++++++ src/prefect/settings/base.py | 5 +- 3 files changed, 106 insertions(+), 12 deletions(-) create mode 100644 src/integrations/prefect-redis/tests/test_client.py diff --git a/src/integrations/prefect-redis/prefect_redis/client.py b/src/integrations/prefect-redis/prefect_redis/client.py index 7534e6bd67ce8..e6950be02006e 100644 --- a/src/integrations/prefect-redis/prefect_redis/client.py +++ b/src/integrations/prefect-redis/prefect_redis/client.py @@ -1,6 +1,6 @@ import asyncio import functools -from typing import Any, Callable +from typing import Any, Callable, Union from pydantic import Field from redis.asyncio import Redis @@ -13,13 +13,21 @@ class RedisSettings(PrefectBaseSettings): - model_config = _build_settings_config(("redis",)) + model_config = _build_settings_config(("redis",), frozen=True) host: str = Field(default="localhost") port: int = Field(default=6379) db: int = Field(default=0) username: str = Field(default="default") password: str = Field(default="") + health_check_interval: int = Field( + default=20, + description="Health check interval for pinging the server; defaults to 20 seconds.", + ) + ssl: bool = Field( + default=False, + description="Whether to use SSL for the Redis connection", + ) CacheKey: TypeAlias = tuple[ @@ -67,14 +75,14 @@ def close_all_cached_connections() -> None: @cached def get_async_redis_client( - host: str | None = None, - port: int | None = None, - db: int | None = None, - password: str | None = None, - username: str | None = None, - health_check_interval: int | None = None, + host: Union[str, None] = None, + port: Union[int, None] = None, + db: Union[int, None] = None, + password: Union[str, None] = None, + username: Union[str, None] = None, + health_check_interval: Union[int, None] = None, decode_responses: bool = True, - ssl: bool | None = None, + ssl: Union[bool, None] = None, ) -> Redis: """Retrieves an async Redis client. @@ -84,7 +92,6 @@ def get_async_redis_client( db: The Redis database to interact with. password: The password for the redis host username: Username for the redis instance - health_check_interval: The health check interval to ping the server on. decode_responses: Whether to decode binary responses from Redis to unicode strings. @@ -99,6 +106,9 @@ def get_async_redis_client( db=db or settings.db, password=password or settings.password, username=username or settings.username, + health_check_interval=health_check_interval or settings.health_check_interval, + ssl=ssl or settings.ssl, + decode_responses=decode_responses, retry_on_timeout=True, ) @@ -116,5 +126,7 @@ def async_redis_from_settings(settings: RedisSettings, **options: Any) -> Redis: db=settings.db, password=settings.password, username=settings.username, + health_check_interval=settings.health_check_interval, + ssl=settings.ssl, **options, ) diff --git a/src/integrations/prefect-redis/tests/test_client.py b/src/integrations/prefect-redis/tests/test_client.py new file mode 100644 index 0000000000000..6a3cfef97a888 --- /dev/null +++ b/src/integrations/prefect-redis/tests/test_client.py @@ -0,0 +1,81 @@ +from unittest.mock import MagicMock, patch + +from prefect_redis.client import ( + RedisSettings, + async_redis_from_settings, + close_all_cached_connections, + get_async_redis_client, +) +from redis.asyncio import Redis + + +def test_redis_settings_defaults(): + """Test that RedisSettings has expected defaults""" + settings = RedisSettings() + assert settings.host == "localhost" + assert settings.port == 6379 + assert settings.db == 0 + assert settings.username == "default" + assert settings.password == "" + assert settings.health_check_interval == 20 + assert settings.ssl is False + + +async def test_get_async_redis_client_defaults(): + """Test that get_async_redis_client creates client with default settings""" + client = get_async_redis_client() + assert isinstance(client, Redis) + assert client.connection_pool.connection_kwargs["host"] == "localhost" + assert client.connection_pool.connection_kwargs["port"] == 6379 + await client.aclose() + + +async def test_get_async_redis_client_custom_params(): + """Test that get_async_redis_client respects custom parameters""" + client = get_async_redis_client( + host="custom.host", + port=6380, + db=1, + username="custom_user", + password="secret", + ) + conn_kwargs = client.connection_pool.connection_kwargs + assert conn_kwargs["host"] == "custom.host" + assert conn_kwargs["port"] == 6380 + assert conn_kwargs["db"] == 1 + assert conn_kwargs["username"] == "custom_user" + assert conn_kwargs["password"] == "secret" + await client.aclose() + + +async def test_async_redis_from_settings(): + """Test creating Redis client from settings object""" + settings = RedisSettings( + host="settings.host", + port=6381, + username="settings_user", + ) + client = async_redis_from_settings(settings) + conn_kwargs = client.connection_pool.connection_kwargs + assert conn_kwargs["host"] == "settings.host" + assert conn_kwargs["port"] == 6381 + assert conn_kwargs["username"] == "settings_user" + await client.aclose() + + +@patch("prefect_redis.client._client_cache") +def test_close_all_cached_connections(mock_cache): + """Test that close_all_cached_connections properly closes all clients""" + mock_client = MagicMock() + mock_loop = MagicMock() + mock_loop.is_closed.return_value = False + + # Mock the coroutines that would be awaited + mock_loop.run_until_complete.return_value = None + + mock_cache.items.return_value = [((None, None, None, mock_loop), mock_client)] + + close_all_cached_connections() + + # Verify run_until_complete was called twice (for disconnect and close) + assert mock_loop.run_until_complete.call_count == 2 diff --git a/src/prefect/settings/base.py b/src/prefect/settings/base.py index 3fa42020bcc69..54486cef8ebf0 100644 --- a/src/prefect/settings/base.py +++ b/src/prefect/settings/base.py @@ -192,7 +192,7 @@ def _add_environment_variables( def _build_settings_config( - path: Tuple[str, ...] = tuple(), + path: Tuple[str, ...] = tuple(), frozen: bool = False ) -> PrefectSettingsConfigDict: env_prefix = f"PREFECT_{'_'.join(path).upper()}_" if path else "PREFECT_" return PrefectSettingsConfigDict( @@ -202,7 +202,8 @@ def _build_settings_config( toml_file="prefect.toml", prefect_toml_table_header=path, pyproject_toml_table_header=("tool", "prefect", *path), - json_schema_extra=_add_environment_variables, + json_schema_extra=_add_environment_variables, # type: ignore + frozen=frozen, ) From 77e87de2a5f6de28eaf72e9f87d1e88eefe65b1e Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Mon, 16 Dec 2024 17:53:11 -0800 Subject: [PATCH 2/2] missed one --- src/integrations/prefect-redis/prefect_redis/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/integrations/prefect-redis/prefect_redis/client.py b/src/integrations/prefect-redis/prefect_redis/client.py index e6950be02006e..1063a55415132 100644 --- a/src/integrations/prefect-redis/prefect_redis/client.py +++ b/src/integrations/prefect-redis/prefect_redis/client.py @@ -34,13 +34,13 @@ class RedisSettings(PrefectBaseSettings): Callable[..., Any], tuple[Any, ...], tuple[tuple[str, Any], ...], - asyncio.AbstractEventLoop | None, + Union[asyncio.AbstractEventLoop, None], ] _client_cache: dict[CacheKey, Redis] = {} -def _running_loop() -> asyncio.AbstractEventLoop | None: +def _running_loop() -> Union[asyncio.AbstractEventLoop, None]: try: return asyncio.get_running_loop() except RuntimeError as e: @@ -62,7 +62,7 @@ def cached_fn(*args: Any, **kwargs: Any) -> Redis: def close_all_cached_connections() -> None: """Close all cached Redis connections.""" - loop: asyncio.AbstractEventLoop | None + loop: Union[asyncio.AbstractEventLoop, None] for (_, _, _, loop), client in _client_cache.items(): if loop and loop.is_closed():