Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added

- Added retry with back-off logic for Redis related functions. [#528](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/528)

### Changed

### Fixed
Expand Down
1 change: 1 addition & 0 deletions stac_fastapi/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"jsonschema~=4.0.0",
"slowapi~=0.1.9",
"redis==6.4.0",
"retry==0.9.2",
]

[project.urls]
Expand Down
125 changes: 85 additions & 40 deletions stac_fastapi/core/stac_fastapi/core/redis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

import json
import logging
from typing import List, Optional, Tuple
from functools import wraps
from typing import Callable, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
from redis import asyncio as aioredis
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import TimeoutError as RedisTimeoutError
from retry import retry # type: ignore

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -143,65 +147,104 @@ def validate_self_link_ttl_standalone(cls, v: int) -> int:
return v


class RedisRetrySettings(BaseSettings):
Copy link
Collaborator

@Gomez324 Gomez324 Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can create a parent class that both RedisSentinelSettings and RedisSettings will inherit from. This way we won't need to keep another class in memory. You can also move the common elements from the child classes to the parent class

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"""Configuration for Redis retry wrapper."""

redis_query_retries_num: int = Field(
default=3, alias="REDIS_QUERY_RETRIES_NUM", gt=0
)
redis_query_initial_delay: float = Field(
default=1.0, alias="REDIS_QUERY_INITIAL_DELAY", gt=0
)
redis_query_backoff: float = Field(default=2.0, alias="REDIS_QUERY_BACKOFF", gt=1)


# Configure only one Redis configuration
sentinel_settings = RedisSentinelSettings()
standalone_settings = RedisSettings()
retry_settings = RedisRetrySettings()


async def connect_redis() -> Optional[aioredis.Redis]:
def redis_retry(func: Callable) -> Callable:
"""Wrap function in retry with back-off logic."""

@wraps(func)
@retry(
exceptions=(RedisConnectionError, RedisTimeoutError),
tries=retry_settings.redis_query_retries_num,
delay=retry_settings.redis_query_initial_delay,
backoff=retry_settings.redis_query_backoff,
logger=logger,
)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)

return wrapper


@redis_retry
async def _connect_redis_internal() -> Optional[aioredis.Redis]:
"""Return a Redis connection Redis or Redis Sentinel."""
try:
if sentinel_settings.REDIS_SENTINEL_HOSTS:
sentinel_nodes = sentinel_settings.get_sentinel_nodes()
sentinel = Sentinel(
sentinel_nodes,
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
)
if sentinel_settings.REDIS_SENTINEL_HOSTS:
sentinel_nodes = sentinel_settings.get_sentinel_nodes()
sentinel = Sentinel(
sentinel_nodes,
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
)

redis = sentinel.master_for(
service_name=sentinel_settings.REDIS_SENTINEL_MASTER_NAME,
db=sentinel_settings.REDIS_DB,
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=sentinel_settings.REDIS_RETRY_TIMEOUT,
client_name=sentinel_settings.REDIS_CLIENT_NAME,
max_connections=sentinel_settings.REDIS_MAX_CONNECTIONS,
health_check_interval=sentinel_settings.REDIS_HEALTH_CHECK_INTERVAL,
)
logger.info("Connected to Redis Sentinel")
redis = sentinel.master_for(
service_name=sentinel_settings.REDIS_SENTINEL_MASTER_NAME,
db=sentinel_settings.REDIS_DB,
decode_responses=sentinel_settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=sentinel_settings.REDIS_RETRY_TIMEOUT,
client_name=sentinel_settings.REDIS_CLIENT_NAME,
max_connections=sentinel_settings.REDIS_MAX_CONNECTIONS,
health_check_interval=sentinel_settings.REDIS_HEALTH_CHECK_INTERVAL,
)
logger.info("Connected to Redis Sentinel")

elif standalone_settings.REDIS_HOST:
pool = aioredis.ConnectionPool(
host=standalone_settings.REDIS_HOST,
port=standalone_settings.REDIS_PORT,
db=standalone_settings.REDIS_DB,
max_connections=standalone_settings.REDIS_MAX_CONNECTIONS,
decode_responses=standalone_settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=standalone_settings.REDIS_RETRY_TIMEOUT,
health_check_interval=standalone_settings.REDIS_HEALTH_CHECK_INTERVAL,
)
redis = aioredis.Redis(
connection_pool=pool, client_name=standalone_settings.REDIS_CLIENT_NAME
)
logger.info("Connected to Redis")
else:
logger.warning("No Redis configuration found")
return None

elif standalone_settings.REDIS_HOST:
pool = aioredis.ConnectionPool(
host=standalone_settings.REDIS_HOST,
port=standalone_settings.REDIS_PORT,
db=standalone_settings.REDIS_DB,
max_connections=standalone_settings.REDIS_MAX_CONNECTIONS,
decode_responses=standalone_settings.REDIS_DECODE_RESPONSES,
retry_on_timeout=standalone_settings.REDIS_RETRY_TIMEOUT,
health_check_interval=standalone_settings.REDIS_HEALTH_CHECK_INTERVAL,
)
redis = aioredis.Redis(
connection_pool=pool, client_name=standalone_settings.REDIS_CLIENT_NAME
)
logger.info("Connected to Redis")
else:
logger.warning("No Redis configuration found")
return None
return redis

return redis

async def connect_redis() -> Optional[aioredis.Redis]:
"""Handle Redis connection."""
try:
return await _connect_redis_internal()
except (
aioredis.ConnectionError,
aioredis.TimeoutError,
) as e:
logger.error(f"Redis connection failed after retries: {e}")
except aioredis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
return None
except aioredis.AuthenticationError as e:
logger.error(f"Redis authentication error: {e}")
return None
except aioredis.TimeoutError as e:
logger.error(f"Redis timeout error: {e}")
return None
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
return None

return None


def get_redis_key(url: str, token: str) -> str:
"""Create Redis key using URL path and token."""
Expand Down Expand Up @@ -230,6 +273,7 @@ def build_url_with_token(base_url: str, token: str) -> str:
)


@redis_retry
async def save_prev_link(
redis: aioredis.Redis, next_url: str, current_url: str, next_token: str
) -> None:
Expand All @@ -243,6 +287,7 @@ async def save_prev_link(
await redis.setex(key, ttl_seconds, current_url)


@redis_retry
async def get_prev_link(
redis: aioredis.Redis, current_url: str, current_token: str
) -> Optional[str]:
Expand Down
98 changes: 98 additions & 0 deletions stac_fastapi/tests/redis/test_redis_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from redis.exceptions import ConnectionError as RedisConnectionError

import stac_fastapi.core.redis_utils as redis_utils
from stac_fastapi.core.redis_utils import connect_redis, get_prev_link, save_prev_link


Expand Down Expand Up @@ -46,3 +48,99 @@ async def test_redis_utils_functions():
redis, "http://mywebsite.com/search", "non_existent_token"
)
assert non_existent is None


@pytest.mark.asyncio
async def test_redis_retry_retries_until_success(monkeypatch):
monkeypatch.setattr(
redis_utils.retry_settings, "redis_query_retries_num", 3, raising=False
)
monkeypatch.setattr(
redis_utils.retry_settings, "redis_query_initial_delay", 0, raising=False
)
monkeypatch.setattr(
redis_utils.retry_settings, "redis_query_backoff", 2.0, raising=False
)

captured_kwargs = {}

def fake_retry(**kwargs):
captured_kwargs.update(kwargs)

def decorator(func):
async def wrapped(*args, **inner_kwargs):
attempts = 0
while True:
try:
attempts += 1
return await func(*args, **inner_kwargs)
except kwargs["exceptions"] as exc:
if attempts >= kwargs["tries"]:
raise exc
continue

return wrapped

return decorator

monkeypatch.setattr(redis_utils, "retry", fake_retry)

call_counter = {"count": 0}

@redis_utils.redis_retry
async def flaky() -> str:
call_counter["count"] += 1
if call_counter["count"] < 3:
raise RedisConnectionError("transient failure")
return "success"

result = await flaky()

assert result == "success"
assert call_counter["count"] == 3
assert (
captured_kwargs["tries"] == redis_utils.retry_settings.redis_query_retries_num
)
assert (
captured_kwargs["delay"] == redis_utils.retry_settings.redis_query_initial_delay
)
assert captured_kwargs["backoff"] == redis_utils.retry_settings.redis_query_backoff


@pytest.mark.asyncio
async def test_redis_retry_raises_after_exhaustion(monkeypatch):
monkeypatch.setattr(
redis_utils.retry_settings, "redis_query_retries_num", 3, raising=False
)
monkeypatch.setattr(
redis_utils.retry_settings, "redis_query_initial_delay", 0, raising=False
)
monkeypatch.setattr(
redis_utils.retry_settings, "redis_query_backoff", 2.0, raising=False
)

def fake_retry(**kwargs):
def decorator(func):
async def wrapped(*args, **inner_kwargs):
attempts = 0
while True:
try:
attempts += 1
return await func(*args, **inner_kwargs)
except kwargs["exceptions"] as exc:
if attempts >= kwargs["tries"]:
raise exc
continue

return wrapped

return decorator

monkeypatch.setattr(redis_utils, "retry", fake_retry)

@redis_utils.redis_retry
async def always_fail() -> str:
raise RedisConnectionError("pernament failure")

with pytest.raises(RedisConnectionError):
await always_fail()