Skip to content

Commit 7b958e0

Browse files
committed
A CredentialsProvider class has been added to allow the user to add his own provider for password rotation
1 parent e6cd4fd commit 7b958e0

File tree

7 files changed

+239
-15
lines changed

7 files changed

+239
-15
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* Added dynaminc_startup_nodes configuration to RedisCluster
1515
* Fix reusing the old nodes' connections when cluster topology refresh is being done
1616
* Fix RedisCluster to immediately raise AuthenticationError without a retry
17+
* Added CredentialsProvider class to support password rotation
1718
* 4.1.3 (Feb 8, 2022)
1819
* Fix flushdb and flushall (#1926)
1920
* Add redis5 and redis4 dockers (#1871)

docs/examples/connection_examples.ipynb

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,92 @@
9797
"user_connection.ping()"
9898
]
9999
},
100+
{
101+
"cell_type": "markdown",
102+
"metadata": {
103+
"collapsed": false
104+
},
105+
"source": [
106+
"## Connecting to a redis instance with AWS Secrets Manager credentials provider."
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": null,
112+
"metadata": {
113+
"collapsed": false,
114+
"pycharm": {
115+
"name": "#%%\n"
116+
}
117+
},
118+
"outputs": [],
119+
"source": [
120+
"import redis\n",
121+
"import boto3\n",
122+
"import json\n",
123+
"import cachetools.func\n",
124+
"\n",
125+
"sm_client = boto3.client('secretsmanager')\n",
126+
" \n",
127+
"def sm_auth_provider(secret_id, version_id=None, version_stage='AWSCURRENT'):\n",
128+
" @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n",
129+
" def get_sm_user_credentials(secret_id, version_id, version_stage):\n",
130+
" secret = sm_client.get_secret_value(secret_id, version_id)\n",
131+
" return json.loads(secret['SecretString'])\n",
132+
" creds = get_sm_user_credentials(secret_id, version_id, version_stage)\n",
133+
" return creds['username'], creds['password']\n",
134+
"\n",
135+
"secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n",
136+
"creds_provider = redis.CredentialsProvider(supplier=sm_auth_provider, secret_id=secret_id)\n",
137+
"user_connection = redis.Redis(host=\"localhost\", port=6379, credentials_provider=creds_provider)\n",
138+
"user_connection.ping()"
139+
]
140+
},
141+
{
142+
"cell_type": "markdown",
143+
"metadata": {},
144+
"source": [
145+
"## Connecting to a redis instance with ElastiCache IAM credentials provider."
146+
]
147+
},
148+
{
149+
"cell_type": "code",
150+
"execution_count": 4,
151+
"metadata": {},
152+
"outputs": [
153+
{
154+
"data": {
155+
"text/plain": [
156+
"True"
157+
]
158+
},
159+
"execution_count": 4,
160+
"metadata": {},
161+
"output_type": "execute_result"
162+
}
163+
],
164+
"source": [
165+
"import redis\n",
166+
"import boto3\n",
167+
"import cachetools.func\n",
168+
"\n",
169+
"ec_client = boto3.client('elasticache')\n",
170+
"\n",
171+
"def iam_auth_provider(user, endpoint, port=6379, region=\"us-east-1\"):\n",
172+
" @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n",
173+
" def get_iam_auth_token(user, endpoint, port, region):\n",
174+
" return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n",
175+
" iam_auth_token = get_iam_auth_token(endpoint, port, user, region)\n",
176+
" return iam_auth_token\n",
177+
"\n",
178+
"username = \"barshaul\"\n",
179+
"endpoint = \"test-001.use1.cache.amazonaws.com\"\n",
180+
"creds_provider = redis.CredentialsProvider(supplier=iam_auth_provider, user=username,\n",
181+
" endpoint=endpoint)\n",
182+
"user_connection = redis.Redis(host=endpoint, port=6379, credentials_provider=creds_provider)\n",
183+
"user_connection.ping()"
184+
]
185+
},
100186
{
101187
"cell_type": "markdown",
102188
"metadata": {},
@@ -176,4 +262,4 @@
176262
},
177263
"nbformat": 4,
178264
"nbformat_minor": 2
179-
}
265+
}

redis/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
BlockingConnectionPool,
77
Connection,
88
ConnectionPool,
9+
CredentialsProvider,
910
SSLConnection,
1011
UnixDomainSocketConnection,
1112
)
@@ -62,6 +63,7 @@ def int_or_str(value):
6263
"Connection",
6364
"ConnectionError",
6465
"ConnectionPool",
66+
"CredentialsProvider",
6567
"DataError",
6668
"from_url",
6769
"InvalidResponse",

redis/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,7 @@ def __init__(
936936
username=None,
937937
retry=None,
938938
redis_connect_func=None,
939+
credentials_provider=None,
939940
):
940941
"""
941942
Initialize a new Redis client.
@@ -977,6 +978,7 @@ def __init__(
977978
"health_check_interval": health_check_interval,
978979
"client_name": client_name,
979980
"redis_connect_func": redis_connect_func,
981+
"credentials_provider": credentials_provider,
980982
}
981983
# based on input, setup appropriate connection args
982984
if unix_socket_path is not None:

redis/cluster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def parse_cluster_shards(resp, **options):
125125
"connection_class",
126126
"connection_pool",
127127
"client_name",
128+
"credentials_provider",
128129
"db",
129130
"decode_responses",
130131
"encoding",

redis/connection.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,42 @@ def read_response(self, disable_decoding=False):
500500
DefaultParser = PythonParser
501501

502502

503+
class CredentialsProvider:
504+
def __init__(self, username="", password="", supplier=None, *args, **kwargs):
505+
"""
506+
Initialize a new Credentials Provider.
507+
:param supplier: a supplier function that returns the username and password.
508+
def supplier(arg1, arg2, ...) -> (username, password)
509+
For examples see examples/connection_examples.ipynb
510+
:param args: arguments to pass to the supplier function
511+
:param kwargs: keyword arguments to pass to the supplier function
512+
"""
513+
self.username = username
514+
self.password = password
515+
self.supplier = supplier
516+
self.args = args
517+
self.kwargs = kwargs
518+
519+
def get_credentials(self):
520+
if self.supplier:
521+
self.username, self.password = self.supplier(*self.args, **self.kwargs)
522+
if self.username:
523+
auth_args = (self.username, self.password or "")
524+
else:
525+
auth_args = (self.password,)
526+
return auth_args
527+
528+
def get_password(self, call_supplier=True):
529+
if call_supplier and self.supplier:
530+
self.username, self.password = self.supplier(*self.args, **self.kwargs)
531+
return self.password
532+
533+
def get_username(self, call_supplier=True):
534+
if call_supplier and self.supplier:
535+
self.username, self.password = self.supplier(*self.args, **self.kwargs)
536+
return self.username
537+
538+
503539
class Connection:
504540
"Manages TCP communication to and from a Redis server"
505541

@@ -526,6 +562,7 @@ def __init__(
526562
username=None,
527563
retry=None,
528564
redis_connect_func=None,
565+
credentials_provider=None,
529566
):
530567
"""
531568
Initialize a new Connection.
@@ -538,9 +575,10 @@ def __init__(
538575
self.host = host
539576
self.port = int(port)
540577
self.db = db
541-
self.username = username
542578
self.client_name = client_name
543-
self.password = password
579+
self.credentials_provider = credentials_provider
580+
if not self.credentials_provider and (username or password):
581+
self.credentials_provider = CredentialsProvider(username, password)
544582
self.socket_timeout = socket_timeout
545583
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
546584
self.socket_keepalive = socket_keepalive
@@ -699,12 +737,9 @@ def on_connect(self):
699737
"Initialize the connection, authenticate and select a database"
700738
self._parser.on_connect(self)
701739

702-
# if username and/or password are set, authenticate
703-
if self.username or self.password:
704-
if self.username:
705-
auth_args = (self.username, self.password or "")
706-
else:
707-
auth_args = (self.password,)
740+
# if credentials provider is set, authenticate
741+
if self.credentials_provider:
742+
auth_args = self.credentials_provider.get_credentials()
708743
# avoid checking health here -- PING will fail if we try
709744
# to check the health prior to the AUTH
710745
self.send_command("AUTH", *auth_args, check_health=False)
@@ -716,7 +751,11 @@ def on_connect(self):
716751
# server seems to be < 6.0.0 which expects a single password
717752
# arg. retry auth with just the password.
718753
# https://github.com/andymccurdy/redis-py/issues/1274
719-
self.send_command("AUTH", self.password, check_health=False)
754+
self.send_command(
755+
"AUTH",
756+
self.credentials_provider.get_password(),
757+
check_health=False,
758+
)
720759
auth_response = self.read_response()
721760

722761
if str_if_bytes(auth_response) != "OK":
@@ -1074,6 +1113,7 @@ def __init__(
10741113
client_name=None,
10751114
retry=None,
10761115
redis_connect_func=None,
1116+
credentials_provider=None,
10771117
):
10781118
"""
10791119
Initialize a new UnixDomainSocketConnection.
@@ -1085,9 +1125,10 @@ def __init__(
10851125
self.pid = os.getpid()
10861126
self.path = path
10871127
self.db = db
1088-
self.username = username
10891128
self.client_name = client_name
1090-
self.password = password
1129+
self.credentials_provider = credentials_provider
1130+
if not self.credentials_provider and (username or password):
1131+
self.credentials_provider = CredentialsProvider(username, password)
10911132
self.socket_timeout = socket_timeout
10921133
self.retry_on_timeout = retry_on_timeout
10931134
if retry_on_error is SENTINEL:

tests/test_connection.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
1+
import random
12
import socket
3+
import string
24
import types
35
from unittest import mock
46
from unittest.mock import patch
57

68
import pytest
79

10+
import redis
811
from redis.backoff import NoBackoff
9-
from redis.connection import Connection
10-
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
12+
from redis.connection import Connection, CredentialsProvider
13+
from redis.exceptions import (
14+
ConnectionError,
15+
InvalidResponse,
16+
ResponseError,
17+
TimeoutError,
18+
)
1119
from redis.retry import Retry
1220
from redis.utils import HIREDIS_AVAILABLE
1321

14-
from .conftest import skip_if_server_version_lt
22+
from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt
1523

1624

1725
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@@ -122,3 +130,86 @@ def test_connect_timeout_error_without_retry(self):
122130
assert conn._connect.call_count == 1
123131
assert str(e.value) == "Timeout connecting to server"
124132
self.clear(conn)
133+
134+
135+
class TestCredentialsProvider:
136+
@skip_if_redis_enterprise()
137+
def test_credentials_provider_without_supplier(self, r, request):
138+
# first, test for default user (`username` is supposed to be optional)
139+
default_username = "default"
140+
temp_pass = "temp_pass"
141+
creds_provider = CredentialsProvider(default_username, temp_pass)
142+
r.config_set("requirepass", temp_pass)
143+
creds = creds_provider.get_credentials()
144+
assert r.auth(creds[1], creds[0]) is True
145+
assert r.auth(creds_provider.get_password()) is True
146+
147+
# test for other users
148+
username = "redis-py-auth"
149+
password = "strong_password"
150+
151+
def teardown():
152+
try:
153+
r.auth(temp_pass)
154+
except ResponseError:
155+
r.auth("default", "")
156+
r.config_set("requirepass", "")
157+
r.acl_deluser(username)
158+
159+
request.addfinalizer(teardown)
160+
161+
assert r.acl_setuser(
162+
username,
163+
enabled=True,
164+
passwords=["+" + password],
165+
keys="~*",
166+
commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"],
167+
)
168+
169+
creds_provider2 = CredentialsProvider(username, password)
170+
r2 = _get_client(
171+
redis.Redis, request, flushdb=False, credentials_provider=creds_provider2
172+
)
173+
174+
assert r2.ping() is True
175+
176+
@skip_if_redis_enterprise()
177+
def test_credentials_provider_with_supplier(self, r, request):
178+
import functools
179+
180+
@functools.lru_cache(maxsize=10)
181+
def auth_supplier(user, endpoint):
182+
def get_random_string(length):
183+
letters = string.ascii_lowercase
184+
result_str = "".join(random.choice(letters) for i in range(length))
185+
return result_str
186+
187+
auth_token = get_random_string(5) + user + "_" + endpoint
188+
return user, auth_token
189+
190+
username = "redis-py-auth"
191+
creds_provider = CredentialsProvider(
192+
supplier=auth_supplier,
193+
user=username,
194+
endpoint="localhost",
195+
)
196+
password = creds_provider.get_password()
197+
198+
assert r.acl_setuser(
199+
username,
200+
enabled=True,
201+
passwords=["+" + password],
202+
keys="~*",
203+
commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"],
204+
)
205+
206+
def teardown():
207+
r.acl_deluser(username)
208+
209+
request.addfinalizer(teardown)
210+
211+
r2 = _get_client(
212+
redis.Redis, request, flushdb=False, credentials_provider=creds_provider
213+
)
214+
215+
assert r2.ping() is True

0 commit comments

Comments
 (0)