|
1 | 1 | import os |
2 | | -from typing import Any, Dict, List, Optional, Type |
| 2 | +from typing import Any, Dict, List, Optional, Type, TypeVar, Union, overload |
| 3 | +from urllib.parse import urlparse |
3 | 4 | from warnings import warn |
4 | 5 |
|
5 | 6 | from redis import Redis, RedisCluster |
|
11 | 12 | from redis.asyncio.connection import SSLConnection as AsyncSSLConnection |
12 | 13 | from redis.connection import SSLConnection |
13 | 14 | from redis.exceptions import ResponseError |
| 15 | +from redis.sentinel import Sentinel |
14 | 16 |
|
15 | 17 | from redisvl import __version__ |
16 | 18 | from redisvl.redis.constants import REDIS_URL_ENV_VAR |
@@ -198,6 +200,9 @@ def parse_attrs(attrs, field_type=None): |
198 | 200 | } |
199 | 201 |
|
200 | 202 |
|
| 203 | +T = TypeVar("T", Redis, AsyncRedis) |
| 204 | + |
| 205 | + |
201 | 206 | class RedisConnectionFactory: |
202 | 207 | """Builds connections to a Redis database, supporting both synchronous and |
203 | 208 | asynchronous clients. |
@@ -259,7 +264,9 @@ def get_redis_connection( |
259 | 264 | variable is not set. |
260 | 265 | """ |
261 | 266 | url = redis_url or get_address_from_env() |
262 | | - if is_cluster_url(url, **kwargs): |
| 267 | + if url.startswith("redis+sentinel"): |
| 268 | + client = RedisConnectionFactory._redis_sentinel_client(url, Redis, **kwargs) |
| 269 | + elif is_cluster_url(url, **kwargs): |
263 | 270 | client = RedisCluster.from_url(url, **kwargs) |
264 | 271 | else: |
265 | 272 | client = Redis.from_url(url, **kwargs) |
@@ -299,7 +306,11 @@ async def _get_aredis_connection( |
299 | 306 | """ |
300 | 307 | url = url or get_address_from_env() |
301 | 308 |
|
302 | | - if is_cluster_url(url, **kwargs): |
| 309 | + if url.startswith("redis+sentinel"): |
| 310 | + client = RedisConnectionFactory._redis_sentinel_client( |
| 311 | + url, AsyncRedis, **kwargs |
| 312 | + ) |
| 313 | + elif is_cluster_url(url, **kwargs): |
303 | 314 | client = AsyncRedisCluster.from_url(url, **kwargs) |
304 | 315 | else: |
305 | 316 | client = AsyncRedis.from_url(url, **kwargs) |
@@ -340,6 +351,10 @@ def get_async_redis_connection( |
340 | 351 | DeprecationWarning, |
341 | 352 | ) |
342 | 353 | url = url or get_address_from_env() |
| 354 | + if url.startswith("redis+sentinel"): |
| 355 | + return RedisConnectionFactory._redis_sentinel_client( |
| 356 | + url, AsyncRedis, **kwargs |
| 357 | + ) |
343 | 358 | return AsyncRedis.from_url(url, **kwargs) |
344 | 359 |
|
345 | 360 | @staticmethod |
@@ -446,3 +461,60 @@ async def validate_async_redis( |
446 | 461 | await redis_client.echo(_lib_name) |
447 | 462 |
|
448 | 463 | # Module validation removed - operations will fail naturally if modules are missing |
| 464 | + |
| 465 | + @staticmethod |
| 466 | + @overload |
| 467 | + def _redis_sentinel_client( |
| 468 | + redis_url: str, redis_class: type[Redis], **kwargs: Any |
| 469 | + ) -> Redis: ... |
| 470 | + |
| 471 | + @staticmethod |
| 472 | + @overload |
| 473 | + def _redis_sentinel_client( |
| 474 | + redis_url: str, redis_class: type[AsyncRedis], **kwargs: Any |
| 475 | + ) -> AsyncRedis: ... |
| 476 | + |
| 477 | + @staticmethod |
| 478 | + def _redis_sentinel_client( |
| 479 | + redis_url: str, redis_class: Union[type[Redis], type[AsyncRedis]], **kwargs: Any |
| 480 | + ) -> Union[Redis, AsyncRedis]: |
| 481 | + sentinel_list, service_name, db, username, password = ( |
| 482 | + RedisConnectionFactory._parse_sentinel_url(redis_url) |
| 483 | + ) |
| 484 | + |
| 485 | + sentinel_kwargs = {} |
| 486 | + if username: |
| 487 | + sentinel_kwargs["username"] = username |
| 488 | + kwargs["username"] = username |
| 489 | + if password: |
| 490 | + sentinel_kwargs["password"] = password |
| 491 | + kwargs["password"] = password |
| 492 | + if db: |
| 493 | + kwargs["db"] = db |
| 494 | + |
| 495 | + sentinel = Sentinel(sentinel_list, sentinel_kwargs=sentinel_kwargs, **kwargs) |
| 496 | + return sentinel.master_for(service_name, redis_class=redis_class, **kwargs) |
| 497 | + |
| 498 | + @staticmethod |
| 499 | + def _parse_sentinel_url(url: str) -> tuple: |
| 500 | + parsed_url = urlparse(url) |
| 501 | + hosts_part = parsed_url.netloc.split("@")[-1] |
| 502 | + sentinel_hosts = hosts_part.split(",") |
| 503 | + |
| 504 | + sentinel_list = [] |
| 505 | + for host in sentinel_hosts: |
| 506 | + host_parts = host.split(":") |
| 507 | + if len(host_parts) == 2: |
| 508 | + sentinel_list.append((host_parts[0], int(host_parts[1]))) |
| 509 | + else: |
| 510 | + sentinel_list.append((host_parts[0], 26379)) |
| 511 | + |
| 512 | + service_name = "mymaster" |
| 513 | + db = None |
| 514 | + if parsed_url.path: |
| 515 | + path_parts = parsed_url.path.split("/") |
| 516 | + service_name = path_parts[1] or "mymaster" |
| 517 | + if len(path_parts) > 2: |
| 518 | + db = path_parts[2] |
| 519 | + |
| 520 | + return sentinel_list, service_name, db, parsed_url.username, parsed_url.password |
0 commit comments