token refresh offset #12136

Merged
merged 33 commits into from
Jul 17, 2020
@@ -64,7 +64,15 @@ def get_token(self, *scopes, **kwargs):
             self._authorization_code = None  # auth codes are single-use
             return token

-        token = self._client.get_cached_access_token(scopes) or self._redeem_refresh_token(scopes, **kwargs)
+        token = self._client.get_cached_access_token(scopes)
+        if not token:
+            token = self._redeem_refresh_token(scopes, **kwargs)
+        elif self._client.should_refresh(token):
+            try:
+                self._redeem_refresh_token(scopes, **kwargs)
+            except Exception:  # pylint: disable=broad-except
+                pass

Member

Should we be logging refreshes which fail here? Is this already done in _redeem_refresh_token?

Member Author

This is a good question.

I am leaning towards not logging it, because:

  • If a valid token is available, the user will continue to use it, so there is no need to log the failed refresh.
  • If no valid token is available, the user cannot get one, and we log that event (already implemented).


+
         if not token:
             raise ClientAuthenticationError(
                 message="No authorization code, cached access token, or refresh token available."
@@ -48,6 +48,11 @@ def get_token(self, *scopes, **kwargs):  # pylint:disable=unused-argument
         token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
         if not token:
             token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
+        elif self._client.should_refresh(token):
+            try:
+                self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
+            except Exception:  # pylint: disable=broad-except
+                pass

Comment on lines 49 to +55
Member

This logic:

if not token:
    # get new token
elif should_refresh:
    try:
        # get new token
    except Exception:
        # swallow

seems to be present in most if not all the credentials. Perhaps it could be moved into a base or mixin, and have the implementation just provide a callback or an override for the # get new token functionality?

Member Author
@xiangyan99, Jul 14, 2020

Agreed. But different credentials have different ways to refresh/redeem tokens. So I have not found a clean way to do it.

Member

What do you think of something like this:

import abc

class CredentialBase(abc.ABC):
    def __init__(self, **kwargs):
        self._client = AadClient(...)

    def _get_token_impl(self, *scopes, **kwargs):
        if not scopes:
            raise ValueError('"get_token" requires at least one scope')

        token = self._client.get_cached_access_token(scopes)
        if not token:
            token = self._request_token(*scopes, **kwargs)
        elif self._client.should_refresh(token):
            try:
                # the cached token is still valid, so a failed refresh is ignored
                self._request_token(*scopes, **kwargs)
            except Exception:  # pylint:disable=broad-except
                pass
        return token

    @abc.abstractmethod
    def _request_token(self, *scopes, **kwargs):
        pass

class Credential(CredentialBase):
    def get_token(self, *scopes, **kwargs):
        """relevant user-facing docstring"""
        return self._get_token_impl(*scopes, **kwargs)

    def _request_token(self, *scopes, **kwargs):
        """get a new token according to this credential's personal idiom"""
        ...
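
For illustration, the base-class idea above can be exercised end to end with a fake client. This is only a sketch under stated assumptions: `FakeClient`, `GetTokenMixin`, and `DemoCredential` are hypothetical names, the `AccessToken` namedtuple is a stand-in for `azure.core.credentials.AccessToken`, and none of it is code from this PR. It assumes only that the client exposes `get_cached_access_token` and `should_refresh`, as in the diff.

```python
import abc
from collections import namedtuple

# Stand-in for azure.core.credentials.AccessToken (assumption: same two fields).
AccessToken = namedtuple("AccessToken", ["token", "expires_on"])


class FakeClient:
    """Hypothetical stand-in for the cache and refresh checks of the real client."""

    def __init__(self, cached=None, refresh_due=False):
        self._cached = cached
        self._refresh_due = refresh_due

    def get_cached_access_token(self, scopes):
        return self._cached

    def should_refresh(self, token):
        return self._refresh_due


class GetTokenMixin(abc.ABC):
    """Shared get_token flow; subclasses only define how a token is requested."""

    def _get_token_impl(self, *scopes, **kwargs):
        if not scopes:
            raise ValueError('"get_token" requires at least one scope')
        token = self._client.get_cached_access_token(scopes)
        if not token:
            # Nothing usable is cached, so errors propagate to the caller.
            token = self._request_token(*scopes, **kwargs)
        elif self._client.should_refresh(token):
            try:
                # Proactive refresh: the cached token is still valid,
                # so a failure here is swallowed.
                self._request_token(*scopes, **kwargs)
            except Exception:  # pylint:disable=broad-except
                pass
        return token

    @abc.abstractmethod
    def _request_token(self, *scopes, **kwargs):
        """Acquire a new token in the credential-specific way."""


class DemoCredential(GetTokenMixin):
    def __init__(self, client, fail=False):
        self._client = client
        self._fail = fail
        self.requests = 0  # counts calls to _request_token

    def get_token(self, *scopes, **kwargs):
        return self._get_token_impl(*scopes, **kwargs)

    def _request_token(self, *scopes, **kwargs):
        self.requests += 1
        if self._fail:
            raise RuntimeError("token endpoint unavailable")
        return AccessToken("fresh", 2000)
```

As in the diff, a credential in its refresh window still returns the cached token to the caller, whether or not the refresh attempt succeeds.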

Member Author

Do you mean we should make a shared credential base class?

I would like to handle that in a separate issue/PR, as a code refactoring.

Member

Refactoring always has a lower priority than new features. Merging this code is an open-ended commitment to maintaining it as is, so it's worth investigating a better organization now. The one I sketched may have its own problems (e.g. multiple inheritance would require some care) but it seems workable. What do you think? Have you tried something similar already?

Member Author

I think that when we refactor by adding a shared class for all credentials, we can go further than just this. But I don't want to rush it right before a release.

         return token

     def _get_auth_client(self, tenant_id, client_id, **kwargs):
@@ -49,6 +49,11 @@ def get_token(self, *scopes, **kwargs):
         token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
         if not token:
             token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
+        elif self._client.should_refresh(token):
+            try:
+                self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
+            except Exception:  # pylint: disable=broad-except
+                pass
         return token

     def _get_auth_client(self, tenant_id, client_id, **kwargs):
@@ -48,9 +48,17 @@ def get_token(self, *scopes, **kwargs):

         token = self._client.get_cached_access_token(scopes)

-        if token:
-            return token
+        if not token:
+            token = self._redeem_refresh_token(scopes, **kwargs)
+        elif self._client.should_refresh(token):
+            try:
+                self._redeem_refresh_token(scopes, **kwargs)
+            except Exception:  # pylint: disable=broad-except
+                pass
+        return token

+    def _redeem_refresh_token(self, scopes, **kwargs):
+        # type: (Sequence[str], **Any) -> Optional[AccessToken]
         if not self._refresh_token:
             self._refresh_token = get_credentials()
             if not self._refresh_token:
@@ -48,13 +48,16 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
         self._cache = cache or TokenCache()
         self._client_id = client_id
         self._pipeline = self._build_pipeline(**kwargs)
+        self._token_refresh_timeout = 30  # default 30s
+        self._token_refresh_offset = 120  # default 2 min
+        self._last_refresh_time = int(time.time())

     def get_cached_access_token(self, scopes, query=None):
         # type: (Sequence[str], Optional[dict]) -> Optional[AccessToken]
         tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
         for token in tokens:
             expires_on = int(token["expires_on"])
-            if expires_on - 300 > int(time.time()):
+            if expires_on - 30 > int(time.time()):

Member

Should this be

Suggested change
-            if expires_on - 30 > int(time.time()):
+            if expires_on - self._token_refresh_timeout > int(time.time()):

or is there some rationale for always using 30 seconds?

Member Author

It is not _token_refresh_timeout.

We don't have a clear design for this value, but it must be less than _token_refresh_offset (default 120). Otherwise it will hide the auto-refresh feature: the cache would stop returning the token before it ever enters the refresh window.

The old value, 300, does not meet that requirement, so I updated it to 30.
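
The constraint in this comment can be checked with a small model. This is a sketch, not code from the PR: the function names are hypothetical, the cool-down (`_token_refresh_timeout`) is deliberately ignored, and a one-hour token lifetime is assumed. It counts the seconds during which a token is both served from the cache and inside the refresh window.

```python
def cache_returns_token(expires_on, now, margin):
    """Cache check: a token is served only while more than `margin` seconds remain."""
    return expires_on - margin > now


def in_refresh_window(expires_on, now, offset):
    """Offset check: a refresh is due once `offset` seconds or fewer remain."""
    return expires_on - now <= offset


def proactive_refresh_seconds(margin, offset, lifetime=3600):
    """Seconds of a token's life in which it is both served and due for refresh."""
    expires_on = lifetime
    return sum(
        1
        for now in range(lifetime)
        if cache_returns_token(expires_on, now, margin)
        and in_refresh_window(expires_on, now, offset)
    )


old_overlap = proactive_refresh_seconds(margin=300, offset=120)  # 0
new_overlap = proactive_refresh_seconds(margin=30, offset=120)   # 90
```

Under these assumptions, a 300-second margin leaves no moment at which a proactive refresh could run, while a 30-second margin leaves 90 seconds per token.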

Member

I wonder whether we need an explicit margin here. The 1s margin in if expires_on > int(time.time()) seems okay to me. My reasoning:

  • functionally, this line served to hardcode token_refresh_offset=300
    • if all cached tokens would expire within 300 seconds, this method would return None, prompting the caller to acquire a new token
  • token_refresh_offset will now be observed by callers of this method
  • when a caller enters its refresh window, it should begin trying to acquire a new token
  • while trying to acquire a new token, the caller should return any valid token it has

One bad outcome that could follow is the caller using a token that expires in flight. That request will fail, but the caller's other option was to raise without sending the request at all, because it couldn't acquire a new token. It seems better to try the request, which could after all succeed.

What do you think?

Member Author

The difference shows up when we are still within the token_refresh_retry_timeout time frame.

Extreme case: the user gets a token from us that expires in 1s. We are still within the token_refresh_retry_timeout time frame, so it does not get refreshed.

vs

They get None from us, which forces a refresh.

Member

But if the credential is waiting on the retry timeout, it won't try to get a new token, regardless of what it gets back from the cache. Returning None in that case only guarantees the current request will fail, no?

Member Author

No. If there is no valid token (the cache returns None), we will try to get one, whether or not we are in the retry timeout window.

The retry timeout only applies when there is a valid token but it is within the refresh offset window.
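
The behavior described in this comment can be summarized as a single decision function. This is a minimal sketch, assuming the defaults from the diff (30-second cache margin, 120-second refresh offset, 30-second retry timeout); `next_action` and its return labels are hypothetical and do not exist in the library.

```python
def next_action(expires_on, now, last_refresh_time, margin=30, offset=120, timeout=30):
    """Return what a credential does for one get_token call.

    expires_on is None when nothing is cached for the requested scopes.
    """
    if expires_on is None or expires_on - margin <= now:
        # No usable token: always try to acquire one; the retry timeout is not consulted.
        return "acquire"
    if expires_on - now > offset:
        # Token is outside the refresh window.
        return "use_cached"
    if now - last_refresh_time < timeout:
        # Inside the refresh window, but the last attempt was too recent.
        return "use_cached"
    # Inside the refresh window and past the cool-down: try to refresh, keep serving the cached token.
    return "use_cached_and_refresh"
```

For example, with `now = 1000`, a token expiring at 1100 triggers a refresh when the last attempt was at 500, but not when it was at 995; a missing token is acquired in both cases.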

Member

Ah, I overlooked this behavior. Credentials should observe the retry timeout when the cache is empty.

Member

I'm not sure this is the behavior we want. If we have no access_token and the first attempt to get one failed, do we really want to hold all requests for 30 seconds before attempting to get one? I think we need to clarify this more.

Member Author

My opinion is that if no token is available, we should try to get one every time the user calls our library, without a cool-down time.

                 return AccessToken(token["secret"], expires_on)
         return None

@@ -63,6 +66,19 @@ def get_cached_refresh_tokens(self, scopes):
         """Assumes all cached refresh tokens belong to the same user"""
         return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))

+    def should_refresh(self, token):
+        # type: (AccessToken) -> bool
+        """ check if the token needs refresh or not
+        """
+        expires_on = int(token.expires_on)
+        now = int(time.time())
+        if expires_on - now > self._token_refresh_offset:
+            return False
+        if now - self._last_refresh_time < self._token_refresh_timeout:
+            return False
+        return True
+
+
     @abc.abstractmethod
     def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
         pass
@@ -85,6 +101,7 @@ def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):

     def _process_response(self, response, request_time):
         # type: (PipelineResponse, int) -> AccessToken
+        self._last_refresh_time = time.time()  # no matter succeed or not, update the last refresh time

         content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)

@@ -80,7 +80,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
         token = self._client.get_cached_access_token(scopes)
         if not token:
             token = await self._redeem_refresh_token(scopes, **kwargs)
-
+        elif self._client.should_refresh(token):
+            try:
+                await self._redeem_refresh_token(scopes, **kwargs)
+            except Exception:  # pylint: disable=broad-except
+                pass
         if not token:
             raise ClientAuthenticationError(
                 message="No authorization code, cached access token, or refresh token available."
@@ -54,6 +54,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
         token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
         if not token:
             token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
+        elif self._client.should_refresh(token):
+            try:
+                await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
+            except Exception:  # pylint: disable=broad-except
+                pass
         return token

     def _get_auth_client(self, tenant_id, client_id, **kwargs):
@@ -55,6 +55,11 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
         token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
         if not token:
             token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
+        elif self._client.should_refresh(token):
+            try:
+                await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
+            except Exception:  # pylint: disable=broad-except
+                pass
         return token

     def _get_auth_client(self, tenant_id, client_id, **kwargs):
@@ -52,9 +52,16 @@ async def get_token(self, *scopes, **kwargs):
             raise ValueError("'get_token' requires at least one scope")

         token = self._client.get_cached_access_token(scopes)
-        if token:
-            return token
+        if not token:
+            token = await self._redeem_refresh_token(scopes, **kwargs)
+        elif self._client.should_refresh(token):
+            try:
+                await self._redeem_refresh_token(scopes, **kwargs)
+            except Exception:  # pylint: disable=broad-except
+                pass
+        return token

+    async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]":
         if not self._refresh_token:
             self._refresh_token = get_credentials()
             if not self._refresh_token:
35 changes: 35 additions & 0 deletions sdk/identity/azure-identity/tests/test_token_refresh_offset.py
@@ -0,0 +1,35 @@
+# ------------------------------------
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# ------------------------------------
+import time
+from azure.identity._internal.aad_client import AadClient
+from azure.core.credentials import AccessToken
+import pytest
+
+try:
+    from unittest import mock
+except ImportError:  # python < 3.3
+    import mock
+
+
+def test_if_refresh():
+    client = AadClient("test", "test")
+    now = int(time.time())
+
+    # do not need refresh
+    token = AccessToken("token", now + 500)
+    should_refresh = client.should_refresh(token)
+    assert not should_refresh
+
+    # need refresh
+    token = AccessToken("token", now + 100)
+    client._last_refresh_time = now - 500
+    should_refresh = client.should_refresh(token)
+    assert should_refresh
+
+    # not exceed cool down time, do not refresh
+    token = AccessToken("token", now + 100)
+    client._last_refresh_time = now - 5
+    should_refresh = client.should_refresh(token)
+    assert not should_refresh
4 changes: 4 additions & 0 deletions sdk/identity/azure-identity/tests/test_vscode_credential.py
@@ -92,9 +92,13 @@ def test_cache_refresh_token():


 def test_no_obtain_token_if_cached():
+    def mock_should_refresh(token):
+        return False
+
     expected_token = AccessToken("token", 42)

     mock_client = mock.Mock(spec=object)
+    mock_client.should_refresh = mock_should_refresh
     mock_client.obtain_token_by_refresh_token = mock.Mock(return_value=expected_token)
     mock_client.get_cached_access_token = mock.Mock(return_value="VALUE")

@@ -95,9 +95,13 @@ async def test_cache_refresh_token():

 @pytest.mark.asyncio
 async def test_no_obtain_token_if_cached():
+    def mock_should_refresh(token):
+        return False
+
     expected_token = AccessToken("token", 42)

     mock_client = mock.Mock(spec=object)
+    mock_client.should_refresh = mock_should_refresh
     token_by_refresh_token = mock.Mock(return_value=expected_token)
     mock_client.obtain_token_by_refresh_token = wrap_in_future(token_by_refresh_token)
     mock_client.get_cached_access_token = mock.Mock(return_value="VALUE")