5
5
import time
6
6
from datetime import datetime , timezone
7
7
from enum import Enum
8
- from typing import Callable , TypeVar
8
+ from typing import Callable , TypeVar , Union
9
9
from unittest import mock
10
10
from unittest .mock import Mock
11
11
from urllib .parse import urlparse
17
17
from redis import Sentinel
18
18
from redis .auth .idp import IdentityProviderInterface
19
19
from redis .auth .token import JWToken
20
+ from redis .auth .token_manager import RetryPolicy , TokenManagerConfig
20
21
from redis .backoff import NoBackoff
21
22
from redis .cache import (
22
23
CacheConfig ,
29
30
from redis .credentials import CredentialProvider
30
31
from redis .exceptions import RedisClusterException
31
32
from redis .retry import Retry
32
- from redis_entraid .cred_provider import EntraIdCredentialsProvider , TokenAuthConfig
33
+ from redis_entraid .cred_provider import (
34
+ DEFAULT_DELAY_IN_MS ,
35
+ DEFAULT_EXPIRATION_REFRESH_RATIO ,
36
+ DEFAULT_LOWER_REFRESH_BOUND_MILLIS ,
37
+ DEFAULT_MAX_ATTEMPTS ,
38
+ DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS ,
39
+ EntraIdCredentialsProvider ,
40
+ )
33
41
from redis_entraid .identity_provider import (
34
42
ManagedIdentityIdType ,
43
+ ManagedIdentityProviderConfig ,
35
44
ManagedIdentityType ,
36
- create_provider_from_managed_identity ,
37
- create_provider_from_service_principal ,
45
+ ServicePrincipalIdentityProviderConfig ,
46
+ _create_provider_from_managed_identity ,
47
+ _create_provider_from_service_principal ,
38
48
)
39
49
from tests .ssl_utils import get_tls_certificates
40
50
@@ -623,41 +633,58 @@ def identity_provider(request) -> IdentityProviderInterface:
623
633
return mock_identity_provider ()
624
634
625
635
auth_type = kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
636
+ config = get_identity_provider_config (request = request )
626
637
627
638
if auth_type == "MANAGED_IDENTITY" :
628
- return _get_managed_identity_provider (request )
639
+ return _create_provider_from_managed_identity (config )
640
+
641
+ return _create_provider_from_service_principal (config )
629
642
630
- return _get_service_principal_provider (request )
631
643
644
+ def get_identity_provider_config (
645
+ request ,
646
+ ) -> Union [ManagedIdentityProviderConfig , ServicePrincipalIdentityProviderConfig ]:
647
+ if hasattr (request , "param" ):
648
+ kwargs = request .param .get ("idp_kwargs" , {})
649
+ else :
650
+ kwargs = {}
651
+
652
+ auth_type = kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
632
653
633
- def _get_managed_identity_provider (request ):
634
- authority = os .getenv ("AZURE_AUTHORITY" )
654
+ if auth_type == AuthType .MANAGED_IDENTITY :
655
+ return _get_managed_identity_provider_config (request )
656
+
657
+ return _get_service_principal_provider_config (request )
658
+
659
+
660
+ def _get_managed_identity_provider_config (request ) -> ManagedIdentityProviderConfig :
635
661
resource = os .getenv ("AZURE_RESOURCE" )
636
- id_value = os .getenv ("AZURE_ID_VALUE " , None )
662
+ id_value = os .getenv ("AZURE_USER_ASSIGNED_MANAGED_ID " , None )
637
663
638
664
if hasattr (request , "param" ):
639
665
kwargs = request .param .get ("idp_kwargs" , {})
640
666
else :
641
667
kwargs = {}
642
668
643
669
identity_type = kwargs .pop ("identity_type" , ManagedIdentityType .SYSTEM_ASSIGNED )
644
- id_type = kwargs .pop ("id_type" , ManagedIdentityIdType .CLIENT_ID )
670
+ id_type = kwargs .pop ("id_type" , ManagedIdentityIdType .OBJECT_ID )
645
671
646
- return create_provider_from_managed_identity (
672
+ return ManagedIdentityProviderConfig (
647
673
identity_type = identity_type ,
648
674
resource = resource ,
649
675
id_type = id_type ,
650
676
id_value = id_value ,
651
- authority = authority ,
652
- ** kwargs ,
677
+ kwargs = kwargs ,
653
678
)
654
679
655
680
656
- def _get_service_principal_provider (request ):
681
+ def _get_service_principal_provider_config (
682
+ request ,
683
+ ) -> ServicePrincipalIdentityProviderConfig :
657
684
client_id = os .getenv ("AZURE_CLIENT_ID" )
658
685
client_credential = os .getenv ("AZURE_CLIENT_SECRET" )
659
- authority = os .getenv ("AZURE_AUTHORITY " )
660
- scopes = os .getenv ("AZURE_REDIS_SCOPES" , [] )
686
+ tenant_id = os .getenv ("AZURE_TENANT_ID " )
687
+ scopes = os .getenv ("AZURE_REDIS_SCOPES" , None )
661
688
662
689
if hasattr (request , "param" ):
663
690
kwargs = request .param .get ("idp_kwargs" , {})
@@ -671,14 +698,14 @@ def _get_service_principal_provider(request):
671
698
if isinstance (scopes , str ):
672
699
scopes = scopes .split ("," )
673
700
674
- return create_provider_from_service_principal (
701
+ return ServicePrincipalIdentityProviderConfig (
675
702
client_id = client_id ,
676
703
client_credential = client_credential ,
677
704
scopes = scopes ,
678
705
timeout = timeout ,
679
706
token_kwargs = token_kwargs ,
680
- authority = authority ,
681
- ** kwargs ,
707
+ tenant_id = tenant_id ,
708
+ app_kwargs = kwargs ,
682
709
)
683
710
684
711
@@ -690,31 +717,29 @@ def get_credential_provider(request) -> CredentialProvider:
690
717
return cred_provider_class (** cred_provider_kwargs )
691
718
692
719
idp = identity_provider (request )
693
- initial_delay_in_ms = cred_provider_kwargs .get ("initial_delay_in_ms" , 0 )
694
- block_for_initial = cred_provider_kwargs .get ("block_for_initial" , False )
695
720
expiration_refresh_ratio = cred_provider_kwargs .get (
696
- "expiration_refresh_ratio" , TokenAuthConfig . DEFAULT_EXPIRATION_REFRESH_RATIO
721
+ "expiration_refresh_ratio" , DEFAULT_EXPIRATION_REFRESH_RATIO
697
722
)
698
723
lower_refresh_bound_millis = cred_provider_kwargs .get (
699
- "lower_refresh_bound_millis" , TokenAuthConfig .DEFAULT_LOWER_REFRESH_BOUND_MILLIS
700
- )
701
- max_attempts = cred_provider_kwargs .get (
702
- "max_attempts" , TokenAuthConfig .DEFAULT_MAX_ATTEMPTS
724
+ "lower_refresh_bound_millis" , DEFAULT_LOWER_REFRESH_BOUND_MILLIS
703
725
)
704
- delay_in_ms = cred_provider_kwargs .get (
705
- "delay_in_ms" , TokenAuthConfig .DEFAULT_DELAY_IN_MS
726
+ max_attempts = cred_provider_kwargs .get ("max_attempts" , DEFAULT_MAX_ATTEMPTS )
727
+ delay_in_ms = cred_provider_kwargs .get ("delay_in_ms" , DEFAULT_DELAY_IN_MS )
728
+
729
+ token_mgr_config = TokenManagerConfig (
730
+ expiration_refresh_ratio = expiration_refresh_ratio ,
731
+ lower_refresh_bound_millis = lower_refresh_bound_millis ,
732
+ token_request_execution_timeout_in_ms = DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS , # noqa
733
+ retry_policy = RetryPolicy (
734
+ max_attempts = max_attempts ,
735
+ delay_in_ms = delay_in_ms ,
736
+ ),
706
737
)
707
738
708
- auth_config = TokenAuthConfig (idp )
709
- auth_config .expiration_refresh_ratio = expiration_refresh_ratio
710
- auth_config .lower_refresh_bound_millis = lower_refresh_bound_millis
711
- auth_config .max_attempts = max_attempts
712
- auth_config .delay_in_ms = delay_in_ms
713
-
714
739
return EntraIdCredentialsProvider (
715
- config = auth_config ,
716
- initial_delay_in_ms = initial_delay_in_ms ,
717
- block_for_initial = block_for_initial ,
740
+ identity_provider = idp ,
741
+ token_manager_config = token_mgr_config ,
742
+ initial_delay_in_ms = delay_in_ms ,
718
743
)
719
744
720
745
0 commit comments