Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose methods for closing async credential transport sessions #9090

Merged
merged 19 commits into from
Jan 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions sdk/core/azure-core/azure/core/credentials_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any
from typing_extensions import Protocol
from .credentials import AccessToken

class AsyncTokenCredential(Protocol):
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
pass

async def close(self) -> None:
pass

async def __aenter__(self):
pass

async def __aexit__(self, exc_type, exc_value, traceback) -> None:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ class _BearerTokenCredentialPolicyBase(object):
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (TokenCredential, *str, Mapping[str, Any]) -> None
def __init__(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> None
super(_BearerTokenCredentialPolicyBase, self).__init__()
self._scopes = scopes
self._credential = credential
self._token = None # type: Optional[AccessToken]

@staticmethod
Expand Down Expand Up @@ -69,6 +68,11 @@ class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPo
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

def __init__(self, credential, *scopes, **kwargs):
# type: (TokenCredential, *str, **Any) -> None
self._credential = credential
super(BearerTokenCredentialPolicy, self).__init__(*scopes, **kwargs)

def on_request(self, request):
# type: (PipelineRequest) -> None
"""Adds a bearer token Authorization header to request and sends request to next policy.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,33 @@
# license information.
# -------------------------------------------------------------------------
import threading
from typing import TYPE_CHECKING

from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest


class AsyncBearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, SansIOHTTPPolicy):
# pylint:disable=too-few-public-methods
"""Adds a bearer token Authorization header to requests.

:param credential: The credential.
:type credential: ~azure.core.credentials.TokenCredential
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
:param str scopes: Lets you specify the type of access needed.
"""

def __init__(self, credential, *scopes, **kwargs):
super().__init__(credential, *scopes, **kwargs)
def __init__(self, credential: "AsyncTokenCredential", *scopes: str, **kwargs: "Any") -> None:
self._credential = credential
self._lock = threading.Lock()
super().__init__(*scopes, **kwargs)

async def on_request(self, request: PipelineRequest):
async def on_request(self, request: "PipelineRequest"):
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

- All credential pipelines include `ProxyPolicy`
([#8945](https://github.com/Azure/azure-sdk-for-python/pull/8945))
- Async credentials are async context managers and have an async `close` method
([#9090](https://github.com/Azure/azure-sdk-for-python/pull/9090))


## 1.1.0 (2019-11-27)
Expand Down
18 changes: 18 additions & 0 deletions sdk/identity/azure-identity/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,24 @@ async transport, such as [aiohttp](https://pypi.org/project/aiohttp/). See
[azure-core documentation](../../core/azure-core/README.md#transport)
for more information.

Async credentials should be closed when they're no longer needed. Each async
credential is an async context manager and defines an async `close` method. For
example:

```py
from azure.identity.aio import DefaultAzureCredential

# call close when the credential is no longer needed
credential = DefaultAzureCredential()
...
await credential.close()

# alternatively, use the credential as an async context manager
credential = DefaultAzureCredential()
async with credential:
...
```

This example demonstrates authenticating the asynchronous `SecretClient` from
[azure-keyvault-secrets][azure_keyvault_secrets] with an asynchronous
credential.
Expand Down
26 changes: 14 additions & 12 deletions sdk/identity/azure-identity/azure/identity/_credentials/chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
from azure.core.credentials import AccessToken, TokenCredential


def _get_error_message(history):
attempts = []
for credential, error in history:
if error:
attempts.append("{}: {}".format(credential.__class__.__name__, error))
else:
attempts.append(credential.__class__.__name__)
return """No credential in this chain provided a token.
Attempted credentials:\n\t{}""".format(
"\n\t".join(attempts)
)


class ChainedTokenCredential(object):
"""A sequence of credentials that is itself a credential.

Expand Down Expand Up @@ -48,16 +61,5 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
history.append((credential, ex.message))
except Exception as ex: # pylint: disable=broad-except
history.append((credential, str(ex)))
error_message = self._get_error_message(history)
error_message = _get_error_message(history)
raise ClientAuthenticationError(message=error_message)

@staticmethod
def _get_error_message(history):
attempts = []
for credential, error in history:
if error:
attempts.append("{}: {}".format(credential.__class__.__name__, error))
else:
attempts.append(credential.__class__.__name__)
return """No credential in this chain provided a token.
Attempted credentials:\n\t{}""".format("\n\t".join(attempts))
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/azure/identity/aio/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ def __init__(
self._pipeline = AsyncPipeline(transport=transport, policies=policies)
super().__init__(**kwargs)

async def __aenter__(self):
await self._pipeline.__aenter__()
return self

async def __aexit__(self, *args):
await self.close()

async def close(self) -> None:
await self._pipeline.__aexit__()

async def request_token(
self,
scopes: "Iterable[str]",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import abc


class AsyncCredentialBase(abc.ABC):
@abc.abstractmethod
async def close(self):
pass

async def __aenter__(self):
lmazuel marked this conversation as resolved.
Show resolved Hide resolved
return self

async def __aexit__(self, *args):
await self.close()

@abc.abstractmethod
async def get_token(self, *scopes, **kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import asyncio
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError
from ... import ChainedTokenCredential as SyncChainedTokenCredential
from .base import AsyncCredentialBase
from ..._credentials.chained import _get_error_message

if TYPE_CHECKING:
from typing import Any
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential


class ChainedTokenCredential(SyncChainedTokenCredential):
class ChainedTokenCredential(AsyncCredentialBase):
"""A sequence of credentials that is itself a credential.

Its :func:`get_token` method calls ``get_token`` on each credential in the sequence, in order, returning the first
valid token received.

:param credentials: credential instances to form the chain
:type credentials: :class:`azure.core.credentials.TokenCredential`
:type credentials: :class:`azure.core.credentials.AsyncTokenCredential`
"""

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
def __init__(self, *credentials: "AsyncTokenCredential") -> None:
if not credentials:
raise ValueError("at least one credential is required")
self.credentials = credentials

async def close(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a concern here: do we expect customer to "async enter" all the credentials in the chain, or should we have a aenter here that loop thourgh all of them and enter them?
Can I see a sample of usage of this one?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I want to enable:

credential = DefaultAzureCredential()
client = FooServiceClient(credential)
# ... time passes, many useful service requests are authorized ...
credential.close()

I think close is the important API. I don't expect anyone to "enter" or "open" a credential. I have aenter doing nothing here because the credential doesn't know which members of its chain will send requests, and transports will open sessions as needed. (At least, our current transport implementations will. Perhaps we should make explicit who's expected to open a transport.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you implement __aenter__/__aexit__ (e.g. you are an async context manager), then you are strongly signalling to users that using async with is general goodness, but you can do the closing yourself if you so see fit.

If we don't want to give an example of intended use with async with, then why is it an async context manager?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be consistent with everything else in the SDK that wraps async transport by exposing __aenter__, __aexit__, and close. This PR does add an async with example to the README, and it's okay to use credentials that way. Every other credential's __aenter__ invokes its transport's __aenter__.

The awkwardness for ChainedTokenCredential.__aenter__ is that if it opens sessions for N credentials, N-1 of them may never be used, at some cost dependent on the HTTP client's implementation. These sessions will all be closed by __aexit__, but I thought it unnecessary to open them given that our async transports will do so as needed.

"""Close the transport sessions of all credentials in the chain."""

await asyncio.gather(*(credential.close() for credential in self.credentials))

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
"""Asynchronously request a token from each credential, in order, returning the first token received.

If no credential provides a token, raises :class:`azure.core.exceptions.ClientAuthenticationError`
Expand All @@ -41,5 +54,5 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
history.append((credential, ex.message))
except Exception as ex: # pylint: disable=broad-except
history.append((credential, str(ex)))
error_message = self._get_error_message(history)
error_message = _get_error_message(history)
raise ClientAuthenticationError(message=error_message)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# ------------------------------------
from typing import TYPE_CHECKING

from .base import AsyncCredentialBase
from .._authn_client import AsyncAuthnClient
from ..._base import ClientSecretCredentialBase, CertificateCredentialBase

Expand All @@ -12,7 +13,7 @@
from azure.core.credentials import AccessToken


class ClientSecretCredential(ClientSecretCredentialBase):
class ClientSecretCredential(ClientSecretCredentialBase, AsyncCredentialBase):
"""Authenticates as a service principal using a client ID and client secret.

:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
Expand All @@ -28,6 +29,15 @@ def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs:
super(ClientSecretCredential, self).__init__(tenant_id, client_id, client_secret, **kwargs)
self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs)

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

await self._client.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
"""Asynchronously request an access token for `scopes`.

Expand All @@ -44,7 +54,7 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
return token # type: ignore


class CertificateCredential(CertificateCredentialBase):
class CertificateCredential(CertificateCredentialBase, AsyncCredentialBase):
"""Authenticates as a service principal using a certificate.

:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
Expand All @@ -57,6 +67,15 @@ class CertificateCredential(CertificateCredentialBase):
defines authorities for other clouds.
"""

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

await self._client.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
"""Asynchronously request an access token for `scopes`.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from azure.core.exceptions import ClientAuthenticationError
from ..._constants import EnvironmentVariables
from .client_credential import CertificateCredential, ClientSecretCredential
from .base import AsyncCredentialBase

if TYPE_CHECKING:
from typing import Any, Optional, Union
from azure.core.credentials import AccessToken


class EnvironmentCredential:
class EnvironmentCredential(AsyncCredentialBase):
"""A credential configured by environment variables.

This credential is capable of authenticating as a service principal using a client secret or a certificate, or as
Expand Down Expand Up @@ -50,6 +51,17 @@ def __init__(self, **kwargs: "Any") -> None:
**kwargs
)

async def __aenter__(self):
if self._credential:
await self._credential.__aenter__()
return self

async def close(self):
"""Close the credential's transport session."""

if self._credential:
await self._credential.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
"""Asynchronously request an access token for `scopes`.

Expand Down
Loading