Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

token refresh offset #12136

Merged
merged 33 commits into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
token refresh offset
  • Loading branch information
xiangyan99 committed Jun 19, 2020
commit dc1a9b2687521cad623e6044bbf0d7f940d63b2b
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ def get_token(self, *scopes, **kwargs):
self._authorization_code = None # auth codes are single-use
return token

token = self._client.get_cached_access_token(scopes) or self._redeem_refresh_token(scopes, **kwargs)
token = self._client.get_cached_access_token(scopes)
if not token:
token = self._redeem_refresh_token(scopes, **kwargs)
elif self._client.is_refresh(token):
self._redeem_refresh_token(scopes, **kwargs)

if not token:
raise ClientAuthenticationError(
message="No authorization code, cached access token, or refresh token available."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
elif self._client.is_refresh(token):
self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def get_token(self, *scopes, **kwargs):
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
elif self._client.is_refresh(token):
self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,17 @@ def get_token(self, *scopes, **kwargs):

token = self._client.get_cached_access_token(scopes)

if token:
return token
if not token:
token = self._redeem_refresh_token(scopes, **kwargs)
elif self._client.is_refresh(token):
try:
self._redeem_refresh_token(scopes, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

def _redeem_refresh_token(self, scopes, **kwargs):
# type: (Sequence[str], **Any) -> Optional[AccessToken]
if not self._refresh_token:
self._refresh_token = get_credentials()
if not self._refresh_token:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@ def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
self._cache = cache or TokenCache()
self._client_id = client_id
self._pipeline = self._build_pipeline(**kwargs)
self._token_refresh_timeout = 30 # default 30s
self._token_refresh_offset = 120 # default 2 min
self._last_refresh_time = int(time.time())

def get_cached_access_token(self, scopes, query=None):
# type: (Sequence[str], Optional[dict]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
for token in tokens:
expires_on = int(token["expires_on"])
if expires_on - 300 > int(time.time()):
if expires_on - 30 > int(time.time()):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be

Suggested change
if expires_on - 30 > int(time.time()):
if expires_on - self._token_refresh_timeout > int(time.time()):

or is there some rationale for always using 30 seconds?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not _token_refresh_timeout.

We don't have a clear design for this value but it must be less than _token_refresh_offset (default to 120). Or it will hide the auto refresh feature.

The old one 300 does not meet the requirement so I updated it to 30.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether we need an explicit margin here. The 1s margin in if expires_on > int(time.time()) seems okay to me. My reasoning:

  • functionally, this line served to hardcode token_refresh_offset=300
    • if all cached tokens would expire within 300 seconds, this method would return None, prompting the caller to acquire a new token
  • token_refresh_offset will now be observed by callers of this method
  • when a caller enters its refresh window, it should begin trying to acquire a new token
  • while trying to acquire a new token, the caller should return any valid token it has

One bad outcome that could follow is the caller using a token that expires in flight. That request will fail, but the caller's other option was to raise without sending the request at all, because it couldn't acquire a new token. It seems better to try the request, which could after all succeed.

What do you think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference is when it is still in token_refresh_retry_timeout time frame.

Extreme case: user gets a token from us which expires in 1s. It is still in token_refresh_retry_timeout time frame so it does not get refreshed.

vs

They get None from us so it forces a refresh.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if the credential is waiting on the retry timeout, it won't try to get a new token, regardless of what it gets back from the cache. Returning None in that case only guarantees the current request will fail, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. if there is no valid token (it returns None), no matter it is in retry timeout window or not, we will try to get one.

Retry timeout only applies to there is A valid token but it is within the refresh offset window.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I overlooked this behavior. Credentials should observe the retry timeout when the cache is empty.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is the behavior we want. If we have no access_token and the first attempt to get one failed, do we really want to hold all requests for 30 seconds before attempting to get one? I think we need to clarify this more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My opinion is if there is no one available, every time user calls our library to get one, we will try it w/o a cool down time.

return AccessToken(token["secret"], expires_on)
return None

Expand All @@ -63,6 +66,19 @@ def get_cached_refresh_tokens(self, scopes):
"""Assumes all cached refresh tokens belong to the same user"""
return self._cache.find(TokenCache.CredentialType.REFRESH_TOKEN, target=list(scopes))

def is_refresh(self, token):
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
# type: (AccessToken) -> bool
""" check if the token needs refresh or not
"""
expires_on = int(token.expires_on)
now = int(time.time())
if expires_on - now > self._token_refresh_offset:
return False
if now - self._last_refresh_time < self._token_refresh_offset:
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
return False
return True


@abc.abstractmethod
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
pass
Expand All @@ -85,6 +101,7 @@ def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):

def _process_response(self, response, request_time):
# type: (PipelineResponse, int) -> AccessToken
self._last_refresh_time = time.time() # no matter succeed or not, update the last refresh time

content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
token = self._client.get_cached_access_token(scopes)
if not token:
token = await self._redeem_refresh_token(scopes, **kwargs)
elif self._client.is_refresh(token):
await self._redeem_refresh_token(scopes, **kwargs)

if not token:
raise ClientAuthenticationError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
elif self._client.is_refresh(token):
await self._client.obtain_token_by_client_certificate(scopes, self._certificate, **kwargs)
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
token = self._client.get_cached_access_token(scopes, query={"client_id": self._client_id})
if not token:
token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
elif self._client.is_refresh(token):
await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,16 @@ async def get_token(self, *scopes, **kwargs):
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_access_token(scopes)
if token:
return token
if not token:
token = await self._redeem_refresh_token(scopes, **kwargs)
elif self._client.is_refresh(token):
try:
await self._redeem_refresh_token(scopes, **kwargs)
except Exception: # pylint: disable=broad-except
pass
return token

async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]":
if not self._refresh_token:
self._refresh_token = get_credentials()
if not self._refresh_token:
Expand Down
35 changes: 35 additions & 0 deletions sdk/identity/azure-identity/tests/test_token_refresh_offset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time
from azure.identity._internal.aad_client import AadClient
from azure.core.credentials import AccessToken
import pytest

try:
from unittest import mock
except ImportError: # python < 3.3
import mock


def test_if_refresh():
client = AadClient("test", "test")
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved
now = int(time.time())

# do not need refresh
token = AccessToken("token", now + 500)
is_refresh = client.is_refresh(token)
assert not is_refresh

# need refresh
token = AccessToken("token", now + 100)
client._last_refresh_time = now - 500
is_refresh = client.is_refresh(token)
assert is_refresh

# not exceed cool down time, do not refresh
token = AccessToken("token", now + 100)
client._last_refresh_time = now - 5
is_refresh = client.is_refresh(token)
assert not is_refresh
4 changes: 4 additions & 0 deletions sdk/identity/azure-identity/tests/test_vscode_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,13 @@ def test_cache_refresh_token():


def test_no_obtain_token_if_cached():
def mock_is_refresh(token):
return False

expected_token = AccessToken("token", 42)

mock_client = mock.Mock(spec=object)
mock_client.is_refresh = mock_is_refresh
mock_client.obtain_token_by_refresh_token = mock.Mock(return_value=expected_token)
mock_client.get_cached_access_token = mock.Mock(return_value="VALUE")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ async def test_cache_refresh_token():

@pytest.mark.asyncio
async def test_no_obtain_token_if_cached():
def mock_is_refresh(token):
return False

expected_token = AccessToken("token", 42)

mock_client = mock.Mock(spec=object)
mock_client.is_refresh = mock_is_refresh
token_by_refresh_token = mock.Mock(return_value=expected_token)
mock_client.obtain_token_by_refresh_token = wrap_in_future(token_by_refresh_token)
mock_client.get_cached_access_token = mock.Mock(return_value="VALUE")
Expand Down