Skip to content

Commit 2537cb8

Browse files
committed
test: Updated CredentialProvider test infrastructure (#3502)
* test: Updated CredentialProvider test infrastructure * Added linter exclusion * Updated dev dependency * Codestyle fixes * Updated async test infra * Added missing constant
1 parent 03bc125 commit 2537cb8

File tree

3 files changed

+124
-74
lines changed

3 files changed

+124
-74
lines changed

dev_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ uvloop
1616
vulture>=2.3.0
1717
wheel>=0.30.0
1818
numpy>=1.24.0
19-
redis-entraid==0.1.0b1
19+
redis-entraid==0.3.0b1

tests/conftest.py

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from datetime import datetime, timezone
77
from enum import Enum
8-
from typing import Callable, TypeVar
8+
from typing import Callable, TypeVar, Union
99
from unittest import mock
1010
from unittest.mock import Mock
1111
from urllib.parse import urlparse
@@ -17,6 +17,7 @@
1717
from redis import Sentinel
1818
from redis.auth.idp import IdentityProviderInterface
1919
from redis.auth.token import JWToken
20+
from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
2021
from redis.backoff import NoBackoff
2122
from redis.cache import (
2223
CacheConfig,
@@ -29,12 +30,21 @@
2930
from redis.credentials import CredentialProvider
3031
from redis.exceptions import RedisClusterException
3132
from redis.retry import Retry
32-
from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig
33+
from redis_entraid.cred_provider import (
34+
DEFAULT_DELAY_IN_MS,
35+
DEFAULT_EXPIRATION_REFRESH_RATIO,
36+
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
37+
DEFAULT_MAX_ATTEMPTS,
38+
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
39+
EntraIdCredentialsProvider,
40+
)
3341
from redis_entraid.identity_provider import (
3442
ManagedIdentityIdType,
43+
ManagedIdentityProviderConfig,
3544
ManagedIdentityType,
36-
create_provider_from_managed_identity,
37-
create_provider_from_service_principal,
45+
ServicePrincipalIdentityProviderConfig,
46+
_create_provider_from_managed_identity,
47+
_create_provider_from_service_principal,
3848
)
3949
from tests.ssl_utils import get_tls_certificates
4050

@@ -623,41 +633,58 @@ def identity_provider(request) -> IdentityProviderInterface:
623633
return mock_identity_provider()
624634

625635
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
636+
config = get_identity_provider_config(request=request)
626637

627638
if auth_type == "MANAGED_IDENTITY":
628-
return _get_managed_identity_provider(request)
639+
return _create_provider_from_managed_identity(config)
640+
641+
return _create_provider_from_service_principal(config)
629642

630-
return _get_service_principal_provider(request)
631643

644+
def get_identity_provider_config(
645+
request,
646+
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
647+
if hasattr(request, "param"):
648+
kwargs = request.param.get("idp_kwargs", {})
649+
else:
650+
kwargs = {}
651+
652+
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
632653

633-
def _get_managed_identity_provider(request):
634-
authority = os.getenv("AZURE_AUTHORITY")
654+
if auth_type == AuthType.MANAGED_IDENTITY:
655+
return _get_managed_identity_provider_config(request)
656+
657+
return _get_service_principal_provider_config(request)
658+
659+
660+
def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
635661
resource = os.getenv("AZURE_RESOURCE")
636-
id_value = os.getenv("AZURE_ID_VALUE", None)
662+
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
637663

638664
if hasattr(request, "param"):
639665
kwargs = request.param.get("idp_kwargs", {})
640666
else:
641667
kwargs = {}
642668

643669
identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
644-
id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID)
670+
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
645671

646-
return create_provider_from_managed_identity(
672+
return ManagedIdentityProviderConfig(
647673
identity_type=identity_type,
648674
resource=resource,
649675
id_type=id_type,
650676
id_value=id_value,
651-
authority=authority,
652-
**kwargs,
677+
kwargs=kwargs,
653678
)
654679

655680

656-
def _get_service_principal_provider(request):
681+
def _get_service_principal_provider_config(
682+
request,
683+
) -> ServicePrincipalIdentityProviderConfig:
657684
client_id = os.getenv("AZURE_CLIENT_ID")
658685
client_credential = os.getenv("AZURE_CLIENT_SECRET")
659-
authority = os.getenv("AZURE_AUTHORITY")
660-
scopes = os.getenv("AZURE_REDIS_SCOPES", [])
686+
tenant_id = os.getenv("AZURE_TENANT_ID")
687+
scopes = os.getenv("AZURE_REDIS_SCOPES", None)
661688

662689
if hasattr(request, "param"):
663690
kwargs = request.param.get("idp_kwargs", {})
@@ -671,14 +698,14 @@ def _get_service_principal_provider(request):
671698
if isinstance(scopes, str):
672699
scopes = scopes.split(",")
673700

674-
return create_provider_from_service_principal(
701+
return ServicePrincipalIdentityProviderConfig(
675702
client_id=client_id,
676703
client_credential=client_credential,
677704
scopes=scopes,
678705
timeout=timeout,
679706
token_kwargs=token_kwargs,
680-
authority=authority,
681-
**kwargs,
707+
tenant_id=tenant_id,
708+
app_kwargs=kwargs,
682709
)
683710

684711

@@ -690,31 +717,29 @@ def get_credential_provider(request) -> CredentialProvider:
690717
return cred_provider_class(**cred_provider_kwargs)
691718

692719
idp = identity_provider(request)
693-
initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0)
694-
block_for_initial = cred_provider_kwargs.get("block_for_initial", False)
695720
expiration_refresh_ratio = cred_provider_kwargs.get(
696-
"expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO
721+
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
697722
)
698723
lower_refresh_bound_millis = cred_provider_kwargs.get(
699-
"lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS
700-
)
701-
max_attempts = cred_provider_kwargs.get(
702-
"max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS
724+
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
703725
)
704-
delay_in_ms = cred_provider_kwargs.get(
705-
"delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS
726+
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
727+
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
728+
729+
token_mgr_config = TokenManagerConfig(
730+
expiration_refresh_ratio=expiration_refresh_ratio,
731+
lower_refresh_bound_millis=lower_refresh_bound_millis,
732+
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa
733+
retry_policy=RetryPolicy(
734+
max_attempts=max_attempts,
735+
delay_in_ms=delay_in_ms,
736+
),
706737
)
707738

708-
auth_config = TokenAuthConfig(idp)
709-
auth_config.expiration_refresh_ratio = expiration_refresh_ratio
710-
auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis
711-
auth_config.max_attempts = max_attempts
712-
auth_config.delay_in_ms = delay_in_ms
713-
714739
return EntraIdCredentialsProvider(
715-
config=auth_config,
716-
initial_delay_in_ms=initial_delay_in_ms,
717-
block_for_initial=block_for_initial,
740+
identity_provider=idp,
741+
token_manager_config=token_mgr_config,
742+
initial_delay_in_ms=delay_in_ms,
718743
)
719744

720745

tests/test_asyncio/conftest.py

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,24 @@
1717
from redis.asyncio.retry import Retry
1818
from redis.auth.idp import IdentityProviderInterface
1919
from redis.auth.token import JWToken
20+
from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
2021
from redis.backoff import NoBackoff
2122
from redis.credentials import CredentialProvider
22-
from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig
23+
from redis_entraid.cred_provider import (
24+
DEFAULT_DELAY_IN_MS,
25+
DEFAULT_EXPIRATION_REFRESH_RATIO,
26+
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
27+
DEFAULT_MAX_ATTEMPTS,
28+
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
29+
EntraIdCredentialsProvider,
30+
)
2331
from redis_entraid.identity_provider import (
2432
ManagedIdentityIdType,
33+
ManagedIdentityProviderConfig,
2534
ManagedIdentityType,
26-
create_provider_from_managed_identity,
27-
create_provider_from_service_principal,
35+
ServicePrincipalIdentityProviderConfig,
36+
_create_provider_from_managed_identity,
37+
_create_provider_from_service_principal,
2838
)
2939
from tests.conftest import REDIS_INFO
3040

@@ -255,41 +265,58 @@ def identity_provider(request) -> IdentityProviderInterface:
255265
return mock_identity_provider()
256266

257267
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
268+
config = get_identity_provider_config(request=request)
258269

259270
if auth_type == "MANAGED_IDENTITY":
260-
return _get_managed_identity_provider(request)
271+
return _create_provider_from_managed_identity(config)
272+
273+
return _create_provider_from_service_principal(config)
274+
275+
276+
def get_identity_provider_config(
277+
request,
278+
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
279+
if hasattr(request, "param"):
280+
kwargs = request.param.get("idp_kwargs", {})
281+
else:
282+
kwargs = {}
261283

262-
return _get_service_principal_provider(request)
284+
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
285+
286+
if auth_type == AuthType.MANAGED_IDENTITY:
287+
return _get_managed_identity_provider_config(request)
263288

289+
return _get_service_principal_provider_config(request)
264290

265-
def _get_managed_identity_provider(request):
266-
authority = os.getenv("AZURE_AUTHORITY")
291+
292+
def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
267293
resource = os.getenv("AZURE_RESOURCE")
268-
id_value = os.getenv("AZURE_ID_VALUE", None)
294+
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
269295

270296
if hasattr(request, "param"):
271297
kwargs = request.param.get("idp_kwargs", {})
272298
else:
273299
kwargs = {}
274300

275301
identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
276-
id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID)
302+
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
277303

278-
return create_provider_from_managed_identity(
304+
return ManagedIdentityProviderConfig(
279305
identity_type=identity_type,
280306
resource=resource,
281307
id_type=id_type,
282308
id_value=id_value,
283-
authority=authority,
284-
**kwargs,
309+
kwargs=kwargs,
285310
)
286311

287312

288-
def _get_service_principal_provider(request):
313+
def _get_service_principal_provider_config(
314+
request,
315+
) -> ServicePrincipalIdentityProviderConfig:
289316
client_id = os.getenv("AZURE_CLIENT_ID")
290317
client_credential = os.getenv("AZURE_CLIENT_SECRET")
291-
authority = os.getenv("AZURE_AUTHORITY")
292-
scopes = os.getenv("AZURE_REDIS_SCOPES", [])
318+
tenant_id = os.getenv("AZURE_TENANT_ID")
319+
scopes = os.getenv("AZURE_REDIS_SCOPES", None)
293320

294321
if hasattr(request, "param"):
295322
kwargs = request.param.get("idp_kwargs", {})
@@ -303,14 +330,14 @@ def _get_service_principal_provider(request):
303330
if isinstance(scopes, str):
304331
scopes = scopes.split(",")
305332

306-
return create_provider_from_service_principal(
333+
return ServicePrincipalIdentityProviderConfig(
307334
client_id=client_id,
308335
client_credential=client_credential,
309336
scopes=scopes,
310337
timeout=timeout,
311338
token_kwargs=token_kwargs,
312-
authority=authority,
313-
**kwargs,
339+
tenant_id=tenant_id,
340+
app_kwargs=kwargs,
314341
)
315342

316343

@@ -322,31 +349,29 @@ def get_credential_provider(request) -> CredentialProvider:
322349
return cred_provider_class(**cred_provider_kwargs)
323350

324351
idp = identity_provider(request)
325-
initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0)
326-
block_for_initial = cred_provider_kwargs.get("block_for_initial", False)
327352
expiration_refresh_ratio = cred_provider_kwargs.get(
328-
"expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO
353+
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
329354
)
330355
lower_refresh_bound_millis = cred_provider_kwargs.get(
331-
"lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS
332-
)
333-
max_attempts = cred_provider_kwargs.get(
334-
"max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS
356+
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
335357
)
336-
delay_in_ms = cred_provider_kwargs.get(
337-
"delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS
358+
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
359+
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
360+
361+
token_mgr_config = TokenManagerConfig(
362+
expiration_refresh_ratio=expiration_refresh_ratio,
363+
lower_refresh_bound_millis=lower_refresh_bound_millis,
364+
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa
365+
retry_policy=RetryPolicy(
366+
max_attempts=max_attempts,
367+
delay_in_ms=delay_in_ms,
368+
),
338369
)
339370

340-
auth_config = TokenAuthConfig(idp)
341-
auth_config.expiration_refresh_ratio = expiration_refresh_ratio
342-
auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis
343-
auth_config.max_attempts = max_attempts
344-
auth_config.delay_in_ms = delay_in_ms
345-
346371
return EntraIdCredentialsProvider(
347-
config=auth_config,
348-
initial_delay_in_ms=initial_delay_in_ms,
349-
block_for_initial=block_for_initial,
372+
identity_provider=idp,
373+
token_manager_config=token_mgr_config,
374+
initial_delay_in_ms=delay_in_ms,
350375
)
351376

352377

0 commit comments

Comments
 (0)