From dcaa3c2e2cbbd4f39239f43a738a4e57a862aa30 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Mon, 9 Sep 2019 16:33:26 -0700 Subject: [PATCH] Enable SSO on Windows (#7006) --- sdk/identity/azure-identity/HISTORY.md | 15 ++++ sdk/identity/azure-identity/README.md | 9 ++- .../azure-identity/azure/identity/__init__.py | 36 +++++++-- .../azure/identity/_authn_client.py | 77 +++++++++++++++++-- .../azure/identity/_constants.py | 3 + .../azure-identity/azure/identity/_version.py | 2 +- .../azure/identity/aio/__init__.py | 34 ++++++-- .../azure/identity/aio/_authn_client.py | 31 +++++++- .../azure/identity/aio/_internal/__init__.py | 7 ++ .../aio/_internal/exception_wrapper.py | 24 ++++++ .../azure/identity/aio/credentials.py | 49 +++++++++++- .../azure/identity/credentials.py | 71 ++++++++++++++++- sdk/identity/azure-identity/setup.py | 8 +- .../azure-identity/tests/test_identity.py | 56 ++++++++++++-- .../tests/test_identity_async.py | 44 ++++++++++- shared_requirements.txt | 1 + 16 files changed, 430 insertions(+), 37 deletions(-) create mode 100644 sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py create mode 100644 sdk/identity/azure-identity/azure/identity/aio/_internal/exception_wrapper.py diff --git a/sdk/identity/azure-identity/HISTORY.md b/sdk/identity/azure-identity/HISTORY.md index 68f2dc2d056e..a8dbbd7951c3 100644 --- a/sdk/identity/azure-identity/HISTORY.md +++ b/sdk/identity/azure-identity/HISTORY.md @@ -1,5 +1,20 @@ # Release History +## 1.0.0b3 (2019-09-10) +### New features: +- `SharedTokenCacheCredential` authenticates with tokens stored in a local +cache shared by Microsoft applications. This enables Azure SDK clients to +authenticate silently after you've signed in to Visual Studio 2019, for +example. `DefaultAzureCredential` includes `SharedTokenCacheCredential` when +the shared cache is available, and environment variable `AZURE_USERNAME` +is set. See the +[README](https://github.com/Azure/azure-sdk-for-python/blob/tree/master/sdk/identity/azure-identity/README.md#single-sign-on) +for more information. + +### Dependency changes: +- New dependency: [`msal-extensions`](https://pypi.org/project/msal-extensions/) +0.1.1 + ## 1.0.0b2 (2019-08-05) ### Breaking changes: - Removed `azure.core.Configuration` from the public API in preparation for a diff --git a/sdk/identity/azure-identity/README.md b/sdk/identity/azure-identity/README.md index 5df7ef702a6c..fbcbc9a71116 100644 --- a/sdk/identity/azure-identity/README.md +++ b/sdk/identity/azure-identity/README.md @@ -62,7 +62,7 @@ configuration: |credential class|identity|configuration |-|-|- -|`DefaultAzureCredential`|service principal, managed identity or user|none for managed identity; [environment variables](#environment-variables) for service principal or user authentication +|`DefaultAzureCredential`|service principal, managed identity, user|none for managed identity, [environment variables](#environment-variables) for service principal or user authentication |`ManagedIdentityCredential`|managed identity|none |`EnvironmentCredential`|service principal|[environment variables](#environment-variables) |`ClientSecretCredential`|service principal|constructor parameters @@ -93,6 +93,13 @@ require platform support. See the [managed identity documentation](https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/services-support-managed-identities) for more information. +### Single sign-on +During local development on Windows, `DefaultAzureCredential` can authenticate +using a single sign-on shared with Microsoft applications, for example Visual +Studio 2019. Because you may have multiple signed in identities, to +authenticate this way you must set the environment variable `AZURE_USERNAME` +with your desired identity's username (typically an email address). + ## Environment variables `DefaultAzureCredential` and `EnvironmentCredential` can be configured with diff --git a/sdk/identity/azure-identity/azure/identity/__init__.py b/sdk/identity/azure-identity/azure/identity/__init__.py index 73ae14193dde..b66b2b95ec85 100644 --- a/sdk/identity/azure-identity/azure/identity/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/__init__.py @@ -2,7 +2,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import os + from ._browser_auth import InteractiveBrowserCredential +from ._constants import EnvironmentVariables from .credentials import ( CertificateCredential, ChainedTokenCredential, @@ -10,6 +13,7 @@ DeviceCodeCredential, EnvironmentCredential, ManagedIdentityCredential, + SharedTokenCacheCredential, UsernamePasswordCredential, ) @@ -18,17 +22,34 @@ class DefaultAzureCredential(ChainedTokenCredential): """ A default credential capable of handling most Azure SDK authentication scenarios. - When environment variable configuration is present, it authenticates as a service principal - using :class:`azure.identity.EnvironmentCredential`. + The identity it uses depends on the environment. When an access token is needed, it requests one using these + identities in turn, stopping when one provides a token: - When environment configuration is not present, it authenticates with a managed identity - using :class:`azure.identity.ManagedIdentityCredential`. + 1. A service principal configured by environment variables. See :class:`~azure.identity.EnvironmentCredential` for + more details. + 2. An Azure managed identity. See :class:`~azure.identity.ManagedIdentityCredential` for more details. + 3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. This requires a + value for the environment variable ``AZURE_USERNAME``. See :class:`~azure.identity.SharedTokenCacheCredential` + for more details. """ def __init__(self, **kwargs): - super(DefaultAzureCredential, self).__init__( - EnvironmentCredential(**kwargs), ManagedIdentityCredential(**kwargs) - ) + credentials = [EnvironmentCredential(**kwargs), ManagedIdentityCredential(**kwargs)] + + # SharedTokenCacheCredential is part of the default only on supported platforms, when $AZURE_USERNAME has a + # value (because the cache may contain tokens for multiple identities and we can only choose one arbitrarily + # without more information from the user), and when $AZURE_PASSWORD has no value (because when $AZURE_USERNAME + # and $AZURE_PASSWORD are set, EnvironmentCredential will be used instead) + if ( + SharedTokenCacheCredential.supported() + and EnvironmentVariables.AZURE_USERNAME in os.environ + and EnvironmentVariables.AZURE_PASSWORD not in os.environ + ): + credentials.append( + SharedTokenCacheCredential(username=os.environ.get(EnvironmentVariables.AZURE_USERNAME), **kwargs) + ) + + super(DefaultAzureCredential, self).__init__(*credentials) __all__ = [ @@ -40,5 +61,6 @@ def __init__(self, **kwargs): "EnvironmentCredential", "InteractiveBrowserCredential", "ManagedIdentityCredential", + "SharedTokenCacheCredential", "UsernamePasswordCredential", ] diff --git a/sdk/identity/azure-identity/azure/identity/_authn_client.py b/sdk/identity/azure-identity/azure/identity/_authn_client.py index 03373511051c..dbb0d0028fa9 100644 --- a/sdk/identity/azure-identity/azure/identity/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/_authn_client.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import abc import calendar import time @@ -14,6 +15,12 @@ from azure.core.pipeline.policies import ContentDecodePolicy, NetworkTraceLoggingPolicy, ProxyPolicy, RetryPolicy from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy from azure.core.pipeline.transport import RequestsTransport +from azure.identity._constants import AZURE_CLI_CLIENT_ID + +try: + ABC = abc.ABC +except AttributeError: # Python 2.7, abc exists, but not ABC + ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore try: from typing import TYPE_CHECKING @@ -29,7 +36,7 @@ from azure.core.pipeline.policies import HTTPPolicy -class AuthnClientBase(object): +class AuthnClientBase(ABC): """Sans I/O authentication client methods""" def __init__(self, auth_url, **kwargs): # pylint:disable=unused-argument @@ -38,20 +45,48 @@ def __init__(self, auth_url, **kwargs): # pylint:disable=unused-argument raise ValueError("auth_url should be the URL of an OAuth endpoint") super(AuthnClientBase, self).__init__() self._auth_url = auth_url - self._cache = TokenCache() + self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache def get_cached_token(self, scopes): # type: (Iterable[str]) -> Optional[AccessToken] - tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, list(scopes)) + tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes)) for token in tokens: - if all((scope in token["target"] for scope in scopes)): - expires_on = int(token["expires_on"]) - if expires_on - 300 > int(time.time()): - return AccessToken(token["secret"], expires_on) + expires_on = int(token["expires_on"]) + if expires_on - 300 > int(time.time()): + return AccessToken(token["secret"], expires_on) return None + def get_refresh_tokens(self, scopes, account): + """Yields all an account's cached refresh tokens except those which have a scope (which is unexpected) that + isn't a superset of ``scopes``.""" + + for token in self._cache.find( + TokenCache.CredentialType.REFRESH_TOKEN, query={"home_account_id": account.get("home_account_id")} + ): + if "target" in token and not all((scope in token["target"] for scope in scopes)): + continue + yield token + + def get_refresh_token_grant_request(self, refresh_token, scopes): + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token["secret"], + "scope": " ".join(scopes), + "client_id": AZURE_CLI_CLIENT_ID, # TODO: first-party app for SDK? + } + return self._prepare_request(form_data=data) + + @abc.abstractmethod + def request_token(self, scopes, method, headers, form_data, params, **kwargs): + pass + + @abc.abstractmethod + def obtain_token_by_refresh_token(self, scopes, username): + pass + def _deserialize_and_cache_token(self, response, scopes, request_time): # type: (PipelineResponse, Iterable[str], int) -> AccessToken + """Deserialize and cache an access token from an AAD response""" # ContentDecodePolicy sets this, and should have raised if it couldn't deserialize the response payload = response.context[ContentDecodePolicy.CONTEXT_NAME] @@ -165,6 +200,34 @@ def request_token( token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time) return token + def obtain_token_by_refresh_token(self, scopes, username): + # type: (Iterable[str], str) -> Optional[AccessToken] + """Acquire an access token using a cached refresh token. Returns ``None`` when that fails, or the cache has no + refresh token. This is only used by SharedTokenCacheCredential and isn't robust enough for anything else.""" + + # find account matching username + accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query={"username": username}) + for account in accounts: + # try each refresh token that might work, return the first access token acquired + for token in self.get_refresh_tokens(scopes, account): + # currently we only support login.microsoftonline.com, which has an alias login.windows.net + # TODO: this must change to support sovereign clouds + environment = account.get("environment") + if not environment or (environment not in self._auth_url and environment != "login.windows.net"): + continue + + request = self.get_refresh_token_grant_request(token, scopes) + request_time = int(time.time()) + response = self._pipeline.run(request, stream=False) + try: + return self._deserialize_and_cache_token( + response=response, scopes=scopes, request_time=request_time + ) + except ClientAuthenticationError: + continue + + return None + @staticmethod def _create_config(**kwargs): # type: (Mapping[str, Any]) -> Configuration diff --git a/sdk/identity/azure-identity/azure/identity/_constants.py b/sdk/identity/azure-identity/azure/identity/_constants.py index 1c2608e5da8b..f497282c826e 100644 --- a/sdk/identity/azure-identity/azure/identity/_constants.py +++ b/sdk/identity/azure-identity/azure/identity/_constants.py @@ -4,6 +4,9 @@ # ------------------------------------ +AZURE_CLI_CLIENT_ID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" + + class EnvironmentVariables: AZURE_CLIENT_ID = "AZURE_CLIENT_ID" AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET" diff --git a/sdk/identity/azure-identity/azure/identity/_version.py b/sdk/identity/azure-identity/azure/identity/_version.py index 946e62e8cf55..4b82cfad0111 100644 --- a/sdk/identity/azure-identity/azure/identity/_version.py +++ b/sdk/identity/azure-identity/azure/identity/_version.py @@ -2,4 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -VERSION = "1.0.0b2" +VERSION = "1.0.0b3" diff --git a/sdk/identity/azure-identity/azure/identity/aio/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/__init__.py index 7610e4d26df3..99b33f6f2524 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/aio/__init__.py @@ -2,12 +2,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import os + +from .._constants import EnvironmentVariables from .credentials import ( CertificateCredential, ChainedTokenCredential, ClientSecretCredential, EnvironmentCredential, ManagedIdentityCredential, + SharedTokenCacheCredential, ) @@ -15,15 +19,34 @@ class DefaultAzureCredential(ChainedTokenCredential): """ A default credential capable of handling most Azure SDK authentication scenarios. - When environment variable configuration is present, it authenticates as a service principal - using :class:`azure.identity.aio.EnvironmentCredential`. + The identity it uses depends on the environment. When an access token is needed, it requests one using these + identities in turn, stopping when one provides a token: - When environment configuration is not present, it authenticates with a managed identity - using :class:`azure.identity.aio.ManagedIdentityCredential`. + 1. A service principal configured by environment variables. See :class:`~azure.identity.aio.EnvironmentCredential` + for more details. + 2. An Azure managed identity. See :class:`~azure.identity.aio.ManagedIdentityCredential` for more details. + 3. On Windows only: a user who has signed in with a Microsoft application, such as Visual Studio. This requires a + value for the environment variable ``AZURE_USERNAME``. See + :class:`~azure.identity.aio.SharedTokenCacheCredential` for more details. """ def __init__(self, **kwargs): - super().__init__(EnvironmentCredential(**kwargs), ManagedIdentityCredential(**kwargs)) + credentials = [EnvironmentCredential(**kwargs), ManagedIdentityCredential(**kwargs)] + + # SharedTokenCacheCredential is part of the default only on supported platforms, when $AZURE_USERNAME has a + # value (because the cache may contain tokens for multiple identities and we can only choose one arbitrarily + # without more information from the user), and when $AZURE_PASSWORD has no value (because when $AZURE_USERNAME + # and $AZURE_PASSWORD are set, EnvironmentCredential will be used instead) + if ( + SharedTokenCacheCredential.supported() + and EnvironmentVariables.AZURE_USERNAME in os.environ + and EnvironmentVariables.AZURE_PASSWORD not in os.environ + ): + credentials.append( + SharedTokenCacheCredential(username=os.environ.get(EnvironmentVariables.AZURE_USERNAME), **kwargs) + ) + + super().__init__(*credentials) __all__ = [ @@ -33,4 +56,5 @@ def __init__(self, **kwargs): "EnvironmentCredential", "ManagedIdentityCredential", "ChainedTokenCredential", + "SharedTokenCacheCredential", ] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py index bebb2c32929f..a6727907d050 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_authn_client.py @@ -5,10 +5,11 @@ import time from typing import Any, Dict, Iterable, Mapping, Optional +from msal import TokenCache from azure.core import Configuration from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline import AsyncPipeline -from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy from azure.core.pipeline.policies import ( AsyncRetryPolicy, ContentDecodePolicy, @@ -16,6 +17,7 @@ NetworkTraceLoggingPolicy, ProxyPolicy, ) +from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy from azure.core.pipeline.transport import AsyncHttpTransport from azure.core.pipeline.transport.requests_asyncio import AsyncioRequestsTransport @@ -61,6 +63,33 @@ async def request_token( token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time) return token + async def obtain_token_by_refresh_token(self, scopes: Iterable[str], username: str) -> Optional[AccessToken]: + """Acquire an access token using a cached refresh token. Returns ``None`` when that fails, or the cache has no + refresh token. This is only used by SharedTokenCacheCredential and isn't robust enough for anything else.""" + + # find account matching username + accounts = self._cache.find(TokenCache.CredentialType.ACCOUNT, query={"username": username}) + for account in accounts: + # try each refresh token that might work, return the first access token acquired + for token in self.get_refresh_tokens(scopes, account): + # currently we only support login.microsoftonline.com, which has an alias login.windows.net + # TODO: this must change to support sovereign clouds + environment = account.get("environment") + if not environment or (environment not in self._auth_url and environment != "login.windows.net"): + continue + + request = self.get_refresh_token_grant_request(token, scopes) + request_time = int(time.time()) + response = await self._pipeline.run(request, stream=False) + try: + return self._deserialize_and_cache_token( + response=response, scopes=scopes, request_time=request_time + ) + except ClientAuthenticationError: + continue + + return None + @staticmethod def _create_config(**kwargs: Mapping[str, Any]) -> Configuration: config = Configuration(**kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py new file mode 100644 index 000000000000..7ca58029041f --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/__init__.py @@ -0,0 +1,7 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from .exception_wrapper import wrap_exceptions + +__all__ = ["wrap_exceptions"] diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/exception_wrapper.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/exception_wrapper.py new file mode 100644 index 000000000000..4d6f40f954cb --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/exception_wrapper.py @@ -0,0 +1,24 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import functools + +from azure.core.exceptions import ClientAuthenticationError + + +def wrap_exceptions(fn): + """Prevents leaking exceptions defined outside azure-core by raising ClientAuthenticationError from them.""" + + @functools.wraps(fn) + async def wrapper(*args, **kwargs): + try: + result = await fn(*args, **kwargs) + return result + except ClientAuthenticationError: + raise + except Exception as ex: # pylint:disable=broad-except + auth_error = ClientAuthenticationError(message="Authentication failed: {}".format(ex)) + raise auth_error from ex + + return wrapper diff --git a/sdk/identity/azure-identity/azure/identity/aio/credentials.py b/sdk/identity/azure-identity/azure/identity/aio/credentials.py index 634a8ef5fef0..56b72bad1411 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/aio/credentials.py @@ -6,16 +6,25 @@ Credentials for asynchronous Azure SDK authentication. """ import os -from typing import Any, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Mapping, Optional, Union from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError from ._authn_client import AsyncAuthnClient +from ._internal import wrap_exceptions from ._managed_identity import ImdsCredential, MsiCredential from .._base import ClientSecretCredentialBase, CertificateCredentialBase from .._constants import Endpoints, EnvironmentVariables -from ..credentials import ChainedTokenCredential as SyncChainedTokenCredential +from ..credentials import ( + ChainedTokenCredential as SyncChainedTokenCredential, + SharedTokenCacheCredential as SyncSharedTokenCacheCredential, +) + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + import msal_extensions + from ._authn_client import AuthnClientBase # pylint:disable=too-few-public-methods @@ -186,3 +195,39 @@ async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint history.append((credential, str(ex))) error_message = self._get_error_message(history) raise ClientAuthenticationError(message=error_message) + + +class SharedTokenCacheCredential(SyncSharedTokenCacheCredential): + """ + Authenticates using tokens in the local cache shared between Microsoft applications. + + :param str username: + Username (typically an email address) of the user to authenticate as. This is required because the local cache + may contain tokens for multiple identities. + """ + + @wrap_exceptions + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """ + Get an access token for `scopes` from the shared cache. If no access token is cached, attempt to acquire one + using a cached refresh token. + + :param str scopes: desired scopes for the token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: + :class:`azure.core.exceptions.ClientAuthenticationError` when the cache is unavailable or no access token + can be acquired from it + """ + + if not self._client: + raise ClientAuthenticationError(message="Shared token cache unavailable") + + token = await self._client.obtain_token_by_refresh_token(scopes, self._username) + if not token: + raise ClientAuthenticationError(message="No cached token found for '{}'".format(self._username)) + + return token + + @staticmethod + def _get_auth_client(cache: "msal_extensions.FileTokenCache") -> "AuthnClientBase": + return AsyncAuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format("common"), cache=cache) diff --git a/sdk/identity/azure-identity/azure/identity/credentials.py b/sdk/identity/azure-identity/azure/identity/credentials.py index 652b840f3afd..091b6472c676 100644 --- a/sdk/identity/azure-identity/azure/identity/credentials.py +++ b/sdk/identity/azure-identity/azure/identity/credentials.py @@ -6,6 +6,7 @@ Credentials for Azure SDK authentication. """ import os +import sys import time from azure.core.credentials import AccessToken @@ -23,9 +24,11 @@ TYPE_CHECKING = False if TYPE_CHECKING: - # pylint:disable=unused-import + # pylint:disable=unused-import,ungrouped-imports from typing import Any, Callable, Dict, Mapping, Optional, Union from azure.core.credentials import TokenCredential + import msal_extensions + from ._authn_client import AuthnClientBase EnvironmentCredentialTypes = Union["CertificateCredential", "ClientSecretCredential", "UsernamePasswordCredential"] @@ -263,7 +266,7 @@ class DeviceCodeCredential(PublicClientCredential): """ def __init__(self, client_id, prompt_callback=None, **kwargs): - # type: (str, Optional[Callable[[str, str], None]], Any) -> None + # type: (str, Optional[Callable[[str, str, str], None]], Any) -> None self._timeout = kwargs.pop("timeout", None) # type: Optional[int] self._prompt_callback = prompt_callback super(DeviceCodeCredential, self).__init__(client_id=client_id, **kwargs) @@ -313,6 +316,70 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument return token +class SharedTokenCacheCredential(object): + """ + Authenticates using tokens in the local cache shared between Microsoft applications. + + :param str username: + Username (typically an email address) of the user to authenticate as. This is required because the local cache + may contain tokens for multiple identities. + """ + + def __init__(self, username, **kwargs): # pylint:disable=unused-argument + # type: (str, **Any) -> None + + self._username = username + + cache = None + + if sys.platform.startswith("win") and "LOCALAPPDATA" in os.environ: + from msal_extensions.token_cache import WindowsTokenCache + + cache = WindowsTokenCache(cache_location=os.environ["LOCALAPPDATA"] + "/.IdentityService/msal.cache") + + # prevent writing to the shared cache + # TODO: seperating deserializing access tokens from caching them would make this cleaner + cache.add = lambda *_: None + + if cache: + self._client = self._get_auth_client(cache) # type: Optional[AuthnClientBase] + else: + self._client = None + + @wrap_exceptions + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument + # type (*str, **Any) -> AccessToken + """ + Get an access token for `scopes` from the shared cache. If no access token is cached, attempt to acquire one + using a cached refresh token. + + :param str scopes: desired scopes for the token + :rtype: :class:`azure.core.credentials.AccessToken` + :raises: + :class:`azure.core.exceptions.ClientAuthenticationError` when the cache is unavailable or no access token + can be acquired from it + """ + + if not self._client: + raise ClientAuthenticationError(message="Shared token cache unavailable") + + token = self._client.obtain_token_by_refresh_token(scopes, self._username) + if not token: + raise ClientAuthenticationError(message="No cached token found for '{}'".format(self._username)) + + return token + + @staticmethod + def supported(): + # type: () -> bool + return sys.platform.startswith("win") + + @staticmethod + def _get_auth_client(cache): + # type: (msal_extensions.FileTokenCache) -> AuthnClientBase + return AuthnClient(Endpoints.AAD_OAUTH2_V2_FORMAT.format("common"), cache=cache) + + class UsernamePasswordCredential(PublicClientCredential): """ Authenticates a user with a username and password. In general, Microsoft doesn't recommend this kind of diff --git a/sdk/identity/azure-identity/setup.py b/sdk/identity/azure-identity/setup.py index aeb53fceeda4..896cfd8e2267 100644 --- a/sdk/identity/azure-identity/setup.py +++ b/sdk/identity/azure-identity/setup.py @@ -69,7 +69,13 @@ "azure", ] ), - install_requires=["azure-core<2.0.0,>=1.0.0b2", "cryptography>=2.1.4", "msal~=0.4.1", "six>=1.6"], + install_requires=[ + "azure-core<2.0.0,>=1.0.0b2", + "cryptography>=2.1.4", + "msal~=0.4.1", + "msal_extensions~=0.1.1", + "six>=1.6", + ], extras_require={ ":python_version<'3.0'": ["azure-nspkg"], ":python_version<'3.3'": ["mock"], diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index 40aa4f8df35d..799d50476523 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -11,7 +11,7 @@ try: from unittest.mock import Mock, patch except ImportError: # python < 3.3 - from mock import Mock, patch + from mock import Mock, patch # type: ignore from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError @@ -236,8 +236,45 @@ def test_imds_credential_retries(): assert mock_send.call_count == 2 + total_retries -def test_default_credential(): - DefaultAzureCredential() +@patch("azure.identity.SharedTokenCacheCredential") +def test_default_credential_shared_cache_use(mock_credential): + mock_credential.supported = Mock(return_value=False) + + # unsupported platform -> default credential shouldn't use shared cache + credential = DefaultAzureCredential() + assert mock_credential.call_count == 0 + assert mock_credential.supported.call_count == 1 + mock_credential.supported.reset_mock() + + # unsupported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential shouldn't use shared cache + credential = DefaultAzureCredential() + assert mock_credential.call_count == 0 + assert mock_credential.supported.call_count == 1 + + mock_credential.supported = Mock(return_value=True) + + # supported platform, $AZURE_USERNAME not set -> default credential shouldn't use shared cache + credential = DefaultAzureCredential() + assert mock_credential.call_count == 0 + assert mock_credential.supported.call_count == 1 + mock_credential.supported.reset_mock() + + # supported platform, $AZURE_USERNAME and $AZURE_PASSWORD set -> default credential shouldn't use shared cache + # (EnvironmentCredential should be used when both variables are set) + with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com", "AZURE_PASSWORD": "***"}): + credential = DefaultAzureCredential() + assert mock_credential.call_count == 0 + + # supported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential should use shared cache + with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com"}): + expected_token = AccessToken("***", 42) + mock_credential.return_value = Mock(get_token=lambda *_: expected_token) + + credential = DefaultAzureCredential() + assert mock_credential.call_count == 1 + + token = credential.get_token("scope") + assert token == expected_token def test_device_code_credential(): @@ -252,7 +289,12 @@ def test_device_code_credential(): # expected requests: discover tenant, start device code flow, poll for completion mock_response(json_payload={"authorization_endpoint": "https://a/b", "token_endpoint": "https://a/b"}), mock_response( - json_payload={"device_code": "_", "user_code": user_code, "verification_uri": verification_uri, "expires_in": expires_in} + json_payload={ + "device_code": "_", + "user_code": user_code, + "verification_uri": verification_uri, + "expires_in": expires_in, + } ), mock_response( json_payload={ @@ -291,7 +333,7 @@ def test_device_code_credential_timeout(): ) credential = DeviceCodeCredential( - client_id="_", prompt_callback=Mock(), transport=transport, timeout=0.1, instance_discovery=False + client_id="_", prompt_callback=Mock(), transport=transport, timeout=0.01, instance_discovery=False ) with pytest.raises(ClientAuthenticationError) as ex: @@ -348,8 +390,8 @@ def test_interactive_credential_timeout(): ) # mock local server blocks long enough to exceed the timeout - timeout = 1 - server_instance = Mock(wait_for_redirect=functools.partial(time.sleep, timeout + 1)) + timeout = 0.01 + server_instance = Mock(wait_for_redirect=functools.partial(time.sleep, timeout + 0.01)) server_class = Mock(return_value=server_instance) credential = InteractiveBrowserCredential( diff --git a/sdk/identity/azure-identity/tests/test_identity_async.py b/sdk/identity/azure-identity/tests/test_identity_async.py index ce432f62473b..204c100a83e5 100644 --- a/sdk/identity/azure-identity/tests/test_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_identity_async.py @@ -6,7 +6,7 @@ import json import os import time -from unittest.mock import Mock +from unittest.mock import Mock, patch import uuid import pytest @@ -280,5 +280,43 @@ async def validate_request(req, *args, **kwargs): assert ex.value.message is success_message -def test_default_credential(): - DefaultAzureCredential() +@pytest.mark.asyncio +async def test_default_credential_shared_cache_use(): + with patch("azure.identity.aio.SharedTokenCacheCredential") as mock_credential: + mock_credential.supported = Mock(return_value=False) + + # unsupported platform -> default credential shouldn't use shared cache + credential = DefaultAzureCredential() + assert mock_credential.call_count == 0 + assert mock_credential.supported.call_count == 1 + mock_credential.supported.reset_mock() + + # unsupported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential shouldn't use shared cache + credential = DefaultAzureCredential() + assert mock_credential.call_count == 0 + assert mock_credential.supported.call_count == 1 + + mock_credential.supported = Mock(return_value=True) + + # supported platform, $AZURE_USERNAME not set -> default credential shouldn't use shared cache + credential = DefaultAzureCredential() + assert mock_credential.call_count == 0 + assert mock_credential.supported.call_count == 1 + mock_credential.supported.reset_mock() + + # supported platform, $AZURE_USERNAME and $AZURE_PASSWORD set -> default credential shouldn't use shared cache + # (EnvironmentCredential should be used when both variables are set) + with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com", "AZURE_PASSWORD": "***"}): + credential = DefaultAzureCredential() + assert mock_credential.call_count == 0 + + # supported platform, $AZURE_USERNAME set, $AZURE_PASSWORD not set -> default credential should use shared cache + with patch.dict("os.environ", {"AZURE_USERNAME": "foo@bar.com"}): + expected_token = AccessToken("***", 42) + mock_credential.return_value=Mock(get_token=asyncio.coroutine(lambda *_: expected_token)) + + credential = DefaultAzureCredential() + assert mock_credential.call_count == 1 + + token = await credential.get_token("scope") + assert token == expected_token diff --git a/shared_requirements.txt b/shared_requirements.txt index c03066fba99e..329d6a2fb6db 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -90,6 +90,7 @@ futures mock typing msal~=0.4.1 +msal_extensions~=0.1.1 msrest>=0.5.0 msrestazure<2.0.0,>=0.4.32 requests>=2.18.4