Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

token refresh offset #12136

Merged
merged 33 commits into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
UserAgentPolicy,
)
from azure.core.pipeline.transport import RequestsTransport, HttpRequest
from ._constants import AZURE_CLI_CLIENT_ID
from ._constants import AZURE_CLI_CLIENT_ID, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
from ._internal import get_default_authority, normalize_authority
from ._internal.user_agent import USER_AGENT

Expand Down Expand Up @@ -65,17 +65,32 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl
authority = normalize_authority(authority) if authority else get_default_authority()
self._auth_url = "/".join((authority, tenant.strip("/"), "oauth2/v2.0/token"))
self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache
self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY
self._token_refresh_offset = DEFAULT_REFRESH_OFFSET
self._last_refresh_time = 0

@property
def auth_url(self):
return self._auth_url

def should_refresh(self, token):
# type: (AccessToken) -> bool
""" check if the token needs refresh or not
"""
expires_on = int(token.expires_on)
now = int(time.time())
if expires_on - now > self._token_refresh_offset:
return False
if now - self._last_refresh_time < self._token_refresh_retry_delay:
return False
return True

def get_cached_token(self, scopes):
# type: (Iterable[str]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes))
for token in tokens:
expires_on = int(token["expires_on"])
if expires_on - 300 > int(time.time()):
if expires_on > int(time.time()):
return AccessToken(token["secret"], expires_on)
return None

Expand Down Expand Up @@ -217,6 +232,7 @@ def request_token(
# type: (...) -> AccessToken
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
request_time = int(time.time())
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time
response = self._pipeline.run(request, stream=False, **kwargs)
token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time)
return token
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
AZURE_CLI_CLIENT_ID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46"
AZURE_VSCODE_CLIENT_ID = "aebc6443-996d-45c2-90f0-388ff96faa56"
VSCODE_CREDENTIALS_SECTION = "VS Code Azure"
DEFAULT_REFRESH_OFFSET = 300
DEFAULT_TOKEN_REFRESH_RETRY_DELAY = 30


class KnownAuthorities:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,15 @@ def get_token(self, *scopes, **kwargs):
self._authorization_code = None # auth codes are single-use
return token

token = self._client.get_cached_access_token(scopes) or self._redeem_refresh_token(scopes, **kwargs)
token = self._client.get_cached_access_token(scopes)
if not token:
token = self._redeem_refresh_token(scopes, **kwargs)
elif self._client.should_refresh(token):
try:
self._redeem_refresh_token(scopes, **kwargs)
chlowell marked this conversation as resolved.
Show resolved Hide resolved
except Exception: # pylint: disable=broad-except
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be logging refreshes which fail here? Is this already done in _redeem_refresh_token?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good question.

I am leaning towards not logging it because:

  • if there is a valid token available, user will continue to use that one and there is no need to log it.
  • if there is no valid token, user cannot get one and we will log that event (already implemented)


if not token:
raise ClientAuthenticationError(
message="No authorization code, cached access token, or refresh token available."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
elif self._client.should_refresh(token):
try:
self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
except Exception: # pylint: disable=broad-except
pass
Comment on lines 49 to +55
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic:

if not token:
  # get new token
elif should_refrsesh:
  try:
     # get new token
  except Exception:
     # swallow

seems to be present in most if not all the credentials. Perhaps it could be moved into a base or mixin, and have the implementation just provide a callback or an override for the # get new token functionality?

Copy link
Member Author

@xiangyan99 xiangyan99 Jul 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. But different credentials have different ways to refresh/redeem tokens. So I have not found a clean way to do it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think of something like this:

class CredentialBase(ABC):
    def __init__(self, **kwargs):
        self._client = AadClient(...)

    def _get_token_impl(*scopes, **kwargs):
        if not scopes:
            raise ValueError('"get_token" requires at least one scope')

        token = self._client.get_cached_access_token(scopes)
        if not token:
            token = self._request_token(scopes, **kwargs)
        elif self._client.should_refresh(token):
            try:
                self._request_token(scopes, **kwargs)
            except Exception:  # pylint:disable=broad-except
                pass
        return token

    @abc.abstractmethod
    def _request_token(self, *scopes, **kwargs):
        pass

class Credential(CredentialBase):
    def get_token(*scopes, **kwargs):
        """relevant user-facing docstring"""
        return self._get_token_impl(*scopes, **kwargs)

    def _request_token(*scopes, **kwargs):
        """get a new token according to this credential's personal idiom"""
        ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean we make a shared credential base?

I would like to have it into a separate issue/PR as code refactoring.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactoring always has a lower priority than new features. Merging this code is an open-ended commitment to maintaining it as is, so it's worth investigating a better organization now. The one I sketched may have its own problems (e.g. multiple inheritance would require some care) but it seems workable. What do you think? Have you tried something similar already?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think when we do refactoring by adding a shared class for all credentials, we can do further than only this. But I don't want to rush it right before a release.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def get_token(self, *scopes, **kwargs):
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
elif self._client.should_refresh(token):
try:
self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,28 +170,37 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

token = self._client.get_cached_token(scopes)
if not token:
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[: -len("/.default")]
params = dict({"api-version": "2018-02-01", "resource": resource}, **self._identity_config)

token = self._refresh_token(*scopes)
elif self._client.should_refresh(token):
try:
token = self._client.request_token(scopes, method="GET", params=params)
except HttpResponseError as ex:
# 400 in response to a token request indicates managed identity is disabled,
# or the identity with the specified client_id is not available
if ex.status_code == 400:
self._endpoint_available = False
message = "ManagedIdentityCredential authentication unavailable. "
if self._identity_config:
message += "The requested identity has not been assigned to this resource."
else:
message += "No identity has been assigned to this resource."
six.raise_from(CredentialUnavailableError(message=message), ex)

# any other error is unexpected
six.raise_from(ClientAuthenticationError(message=ex.message, response=ex.response), None)
token = self._refresh_token(*scopes)
except Exception: # pylint: disable=broad-except
pass

return token

def _refresh_token(self, *scopes):
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[: -len("/.default")]
params = dict({"api-version": "2018-02-01", "resource": resource}, **self._identity_config)

try:
token = self._client.request_token(scopes, method="GET", params=params)
except HttpResponseError as ex:
# 400 in response to a token request indicates managed identity is disabled,
# or the identity with the specified client_id is not available
if ex.status_code == 400:
self._endpoint_available = False
message = "ManagedIdentityCredential authentication unavailable. "
if self._identity_config:
message += "The requested identity has not been assigned to this resource."
else:
message += "No identity has been assigned to this resource."
six.raise_from(CredentialUnavailableError(message=message), ex)

# any other error is unexpected
six.raise_from(ClientAuthenticationError(message=ex.message, response=ex.response), None)
return token


Expand Down Expand Up @@ -227,16 +236,25 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument

token = self._client.get_cached_token(scopes)
if not token:
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[: -len("/.default")]
secret = os.environ.get(EnvironmentVariables.MSI_SECRET)
if secret:
# MSI_ENDPOINT and MSI_SECRET set -> App Service
token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret)
else:
# only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell)
token = self._request_legacy_token(scopes=scopes, resource=resource)
token = self._refresh_token(*scopes)
elif self._client.should_refresh(token):
try:
token = self._refresh_token(*scopes)
except Exception: # pylint: disable=broad-except
pass
return token

def _refresh_token(self, *scopes):
resource = scopes[0]
if resource.endswith("/.default"):
resource = resource[: -len("/.default")]
secret = os.environ.get(EnvironmentVariables.MSI_SECRET)
if secret:
# MSI_ENDPOINT and MSI_SECRET set -> App Service
token = self._request_app_service_token(scopes=scopes, resource=resource, secret=secret)
else:
# only MSI_ENDPOINT set -> legacy-style MSI (Cloud Shell)
token = self._request_legacy_token(scopes=scopes, resource=resource)
return token

def _request_app_service_token(self, scopes, resource, secret):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from .._internal.aad_client import AadClient

if sys.platform.startswith("win"):
from .win_vscode_adapter import get_credentials
from .._internal.win_vscode_adapter import get_credentials
elif sys.platform.startswith("darwin"):
from .macos_vscode_adapter import get_credentials
from .._internal.macos_vscode_adapter import get_credentials
else:
from .linux_vscode_adapter import get_credentials
from .._internal.linux_vscode_adapter import get_credentials

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
Expand Down Expand Up @@ -47,9 +47,17 @@ def get_token(self, *scopes, **kwargs):

token = self._client.get_cached_access_token(scopes)

if token:
return token
if not token:
token = self._redeem_refresh_token(scopes, **kwargs)
elif self._client.should_refresh(token):
try:
self._redeem_refresh_token(scopes, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _redeem_refresh_token(self, scopes, **kwargs):
# type: (Sequence[str], **Any) -> Optional[AccessToken]
if not self._refresh_token:
self._refresh_token = get_credentials()
if not self._refresh_token:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from . import get_default_authority, normalize_authority
from .._constants import DEFAULT_TOKEN_REFRESH_RETRY_DELAY, DEFAULT_REFRESH_OFFSET

try:
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -48,13 +49,16 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
self._cache = cache or TokenCache()
self._client_id = client_id
self._pipeline = self._build_pipeline(**kwargs)
self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY
self._token_refresh_offset = DEFAULT_REFRESH_OFFSET
self._last_refresh_time = 0

def get_cached_access_token(self, scopes, query=None):
# type: (Sequence[str], Optional[dict]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
for token in tokens:
expires_on = int(token["expires_on"])
if expires_on - 300 > int(time.time()):
if expires_on > int(time.time()):
return AccessToken(token["secret"], expires_on)
return None

Expand All @@ -63,6 +67,19 @@ def get_cached_refresh_tokens(self, scopes):
"""Assumes all cached refresh tokens belong to the same user"""
return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))

def should_refresh(self, token):
# type: (AccessToken) -> bool
""" check if the token needs refresh or not
"""
expires_on = int(token.expires_on)
now = int(time.time())
if expires_on - now > self._token_refresh_offset:
return False
if now - self._last_refresh_time < self._token_refresh_retry_delay:
return False
return True


@abc.abstractmethod
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
pass
Expand All @@ -85,6 +102,7 @@ def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):

def _process_response(self, response, request_time):
# type: (PipelineResponse, int) -> AccessToken
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time

content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def request_token(
) -> AccessToken:
request = self._prepare_request(method, headers=headers, form_data=form_data, params=params)
request_time = int(time.time())
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time
response = await self._pipeline.run(request, stream=False, **kwargs)
token = self._deserialize_and_cache_token(response=response, scopes=scopes, request_time=request_time)
return token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
token = self._client.get_cached_access_token(scopes)
if not token:
token = await self._redeem_refresh_token(scopes, **kwargs)

elif self._client.should_refresh(token):
try:
await self._redeem_refresh_token(scopes, **kwargs)
except Exception: # pylint: disable=broad-except
pass
if not token:
raise ClientAuthenticationError(
message="No authorization code, cached access token, or refresh token available."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
elif self._client.should_refresh(token):
try:
await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
elif self._client.should_refresh(token):
try:
await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Loading