Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Credentials accept tenant_id argument to get_token #19602

Merged
merged 19 commits into from
Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
SharedTokenCacheCredential
  • Loading branch information
chlowell committed Jul 6, 2021
commit 1d286c9479b0293a4a1710f337932937ecbd5313
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .. import CredentialUnavailableError
from .._constants import DEVELOPER_SIGN_ON_CLIENT_ID
from .._internal import AadClient, validate_tenant_id
from .._internal import AadClient, resolve_tenant, validate_tenant_id
from .._internal.decorators import log_get_token, wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase
Expand All @@ -24,7 +24,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Optional
from typing import Any, Dict, Optional
from .. import AuthenticationRecord
from .._internal import AadClientBase

Expand All @@ -46,6 +46,10 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
:keyword cache_persistence_options: configuration for persistent token caching. If not provided, the credential
will use the persistent cache shared by Microsoft development applications
:paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions
:keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant
the user is registered in. When False, which is the default, the credential will acquire tokens only from the
user's home tenant or, if a value was given for **authentication_record**, the tenant specified by the
:class:`AuthenticationRecord`.
"""

def __init__(self, username=None, **kwargs):
Expand All @@ -56,9 +60,10 @@ def __init__(self, username=None, **kwargs):
# authenticate in the tenant that produced the record unless "tenant_id" specifies another
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
validate_tenant_id(self._tenant_id)
self._allow_multitenant = kwargs.pop("allow_multitenant_authentication", False)
self._cache = kwargs.pop("_cache", None)
self._app = None
self._client_kwargs = kwargs
self._client_applications = {} # type: Dict[str, PublicClientApplication]
self._msal_client = MsalClient(**kwargs)
self._initialized = False
else:
super(SharedTokenCacheCredential, self).__init__(username=username, **kwargs)
Expand Down Expand Up @@ -101,7 +106,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account):
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token)
token = self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs)
return token

raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))
Expand All @@ -119,34 +124,35 @@ def _initialize(self):
return

self._load_cache()
if self._cache:
if "AZURE_IDENTITY_DISABLE_CP1" in os.environ:
capabilities = None
else:
capabilities = ["CP1"] # able to handle CAE claims challenges
self._app = PublicClientApplication(
self._initialized = True

def _get_client_application(self, **kwargs):
tenant_id = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs)
if tenant_id not in self._client_applications:
# CP1 = can handle claims challenges (CAE)
capabilities = None if "AZURE_IDENTITY_DISABLE_CP1" in os.environ else ["CP1"]
self._client_applications[tenant_id] = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id),
authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
token_cache=self._cache,
http_client=MsalClient(**self._client_kwargs),
http_client=self._msal_client,
client_capabilities=capabilities
)

self._initialized = True
return self._client_applications[tenant_id]

@wrap_exceptions
def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Silently acquire a token from MSAL. Requires an AuthenticationRecord."""

# self._auth_record and ._app will not be None when this method is called by get_token
# but should either be None anyway (and to satisfy mypy) we raise
if self._app is None or self._auth_record is None:
# this won't be None when this method is called by get_token but we check anyway to satisfy mypy
if self._auth_record is None:
raise CredentialUnavailableError("Initialization failed")

result = None

accounts_for_user = self._app.get_accounts(username=self._auth_record.username)
client_application = self._get_client_application(**kwargs)
accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
if not accounts_for_user:
raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

Expand All @@ -155,7 +161,7 @@ def _acquire_token_silent(self, *scopes, **kwargs):
continue

now = int(time.time())
result = self._app.acquire_token_silent_with_error(
result = client_application.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
if result and "access_token" in result and "expires_in" in result:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncContextManager):
:keyword cache_persistence_options: configuration for persistent token caching. If not provided, the credential
will use the persistent cache shared by Microsoft development applications
:paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions
:keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant
the user is registered in. When False, which is the default, the credential will acquire tokens only from the
user's home tenant.
"""

async def __aenter__(self):
Expand Down Expand Up @@ -78,7 +81,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py

# try each refresh token, returning the first access token acquired
for refresh_token in self._get_refresh_tokens(account):
token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token)
token = await self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs)
return token

raise CredentialUnavailableError(message=NO_TOKEN.format(account.get("username")))
Expand Down