Skip to content
This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Commit

Permalink
Adapt msrestazure to session improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
lmazuel committed Mar 30, 2018
1 parent a0b015a commit ee3bf75
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 71 deletions.
135 changes: 85 additions & 50 deletions msrestazure/azure_active_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,6 @@ def _parse_token(self):
if self.token.get('expires_at'):
countdown = float(self.token['expires_at']) - time.time()
self.token['expires_in'] = countdown
kwargs = {}
if self.token.get('refresh_token'):
kwargs['auto_refresh_url'] = self.token_uri
kwargs['auto_refresh_kwargs'] = {'client_id': self.id,
'resource': self.resource}
kwargs['token_updater'] = self._default_token_cache
return kwargs

def _default_token_cache(self, token):
"""Store token for future sessions.
Expand Down Expand Up @@ -228,23 +221,19 @@ def _retrieve_stored_token(self):
self.token = ast.literal_eval(str(token))
self.signed_session()

def signed_session(self):
def signed_session(self, session=None):
"""Create token-friendly Requests session, using auto-refresh.
Used internally when a request is made.
:rtype: requests_oauthlib.OAuth2Session
:raises: TokenExpiredError if token can no longer be refreshed.
"""
kwargs = self._parse_token()
try:
new_session = oauth.OAuth2Session(
self.id,
token=self.token,
**kwargs)
return new_session
except TokenExpiredError as err:
raise_with_traceback(Expired, "", err)
If a session object is provided, configure it directly. Otherwise,
create a new session and return it.
:param session: The session to configure for authentication
:type session: requests.Session
"""
self._parse_token()
return super(AADMixin, self).signed_session(session)

def clear_cached_token(self):
"""Clear any stored tokens.
Expand All @@ -256,25 +245,6 @@ def clear_cached_token(self):
raise_with_traceback(KeyError, "Unable to clear token.")


class AADRefreshMixin(object):
"""Additional token refresh logic.
"""

def refresh_session(self):
"""Return updated session if token has expired, attempts to
refresh using newly acquired token.
:rtype: requests.Session.
"""
if self.token.get('refresh_token'):
try:
return self.signed_session()
except Expired:
pass
self.set_token()
return self.signed_session()


class AADTokenCredentials(AADMixin):
"""
Credentials objects for AAD token retrieved through external process
Expand Down Expand Up @@ -315,12 +285,12 @@ def retrieve_session(cls, client_id=None):
"""Create AADTokenCredentials from a cached token if it has not
yet expired.
"""
session = cls(None, None, client_id=client_id, cached=True)
session = cls(None, client_id=client_id, cached=True)
session._retrieve_stored_token()
return session


class UserPassCredentials(AADRefreshMixin, AADMixin):
class UserPassCredentials(AADMixin):
"""Credentials object for Headless Authentication,
i.e. AAD authentication via username and password.
Expand Down Expand Up @@ -395,7 +365,8 @@ def set_token(self):
if self.secret:
optional['client_secret'] = self.secret
try:
token = session.fetch_token(self.token_uri, client_id=self.id,
token = session.fetch_token(self.token_uri,
client_id=self.id,
username=self.username,
password=self.password,
resource=self.resource,
Expand All @@ -407,9 +378,42 @@ def set_token(self):
raise_with_traceback(AuthenticationError, "", err)

self.token = token
self._default_token_cache(self.token)

def refresh_session(self, session=None):
"""Return updated session if token has expired, attempts to
refresh using newly acquired token.
If a session object is provided, configure it directly. Otherwise,
create a new session and return it.
:param session: The session to configure for authentication
:type session: requests.Session
:rtype: requests.Session.
"""
with self._setup_session() as session:
optional = {}
if self.secret:
optional['client_secret'] = self.secret
try:
token = session.refresh_token(self.token_uri,
client_id=self.id,
username=self.username,
password=self.password,
resource=self.resource,
verify=self.verify,
proxies=self.proxies,
timeout=self.timeout,
**optional)
except (RequestException, OAuth2Error, InvalidGrantError) as err:
raise_with_traceback(AuthenticationError, "", err)

self.token = token
self._default_token_cache(self.token)
return self.signed_session(session)


class ServicePrincipalCredentials(AADRefreshMixin, AADMixin):
class ServicePrincipalCredentials(AADMixin):
"""Credentials object for Service Principle Authentication.
Authenticates via a Client ID and Secret.
Expand Down Expand Up @@ -466,7 +470,8 @@ def set_token(self):
"""
with self._setup_session() as session:
try:
token = session.fetch_token(self.token_uri, client_id=self.id,
token = session.fetch_token(self.token_uri,
client_id=self.id,
resource=self.resource,
client_secret=self.secret,
response_type="client_credentials",
Expand All @@ -477,6 +482,24 @@ def set_token(self):
raise_with_traceback(AuthenticationError, "", err)
else:
self.token = token
self._default_token_cache(self.token)

def refresh_session(self, session=None):
"""Alias to signed_session().
SP flow does not contain refresh_token, so this method is just asking a new
token to AD.
If a session object is provided, configure it directly. Otherwise,
create a new session and return it.
:param session: The session to configure for authentication
:type session: requests.Session
:rtype: requests.Session.
"""
self.set_token()
return self.signed_session(session)


# For backward compatibility of import, but I doubt someone uses that...
class InteractiveCredentials(object):
Expand Down Expand Up @@ -540,14 +563,17 @@ def __init__(self, adal_method, *args, **kwargs):
self._args = args
self._kwargs = kwargs

def signed_session(self):
"""Get a signed session for requests.
def signed_session(self, session=None):
"""Create requests session with any required auth headers applied.
Usually called by the Azure SDKs for you to authenticate queries.
If a session object is provided, configure it directly. Otherwise,
create a new session and return it.
:param session: The session to configure for authentication
:type session: requests.Session
:rtype: requests.Session
"""
session = super(AdalAuthentication, self).signed_session()
session = super(AdalAuthentication, self).signed_session(session)

try:
raw_token = self._adal_method(*self._args, **self._kwargs)
Expand Down Expand Up @@ -691,10 +717,19 @@ def set_token(self):
token_entry = self._vm_msi.get_token()
self.scheme, self.token = token_entry['token_type'], token_entry

def signed_session(self):
def signed_session(self, session=None):
"""Create requests session with any required auth headers applied.
If a session object is provided, configure it directly. Otherwise,
create a new session and return it.
:param session: The session to configure for authentication
:type session: requests.Session
:rtype: requests.Session
"""
# Token cache is handled by the VM extension, call each time to avoid expiration
self.set_token()
return super(MSIAuthentication, self).signed_session()
return super(MSIAuthentication, self).signed_session(session)


class _ImdsTokenProvider(object):
Expand Down
21 changes: 0 additions & 21 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,27 +193,6 @@ def test_credentials_retrieve_session(self, mock_retrieve):
with self.assertRaises(TokenExpiredError):
ServicePrincipalCredentials.retrieve_session("client_id")

@mock.patch('msrestazure.azure_active_directory.oauth')
def test_credentials_signed_session(self, mock_requests):

creds = mock.create_autospec(ServicePrincipalCredentials)
creds._parse_token = lambda: AADMixin._parse_token(creds)
creds.id = 'client_id'
creds.token_uri = "token_uri"
creds.resource = "resource"

creds.token = {'expires_at':'1',
'expires_in':'2',
'refresh_token':"test"}

AADMixin.signed_session(creds)
mock_requests.OAuth2Session.assert_called_with(
'client_id',
token=creds.token,
auto_refresh_url='token_uri',
auto_refresh_kwargs={'client_id':'client_id', 'resource':'resource'},
token_updater=creds._default_token_cache)

def test_service_principal(self):

creds = mock.create_autospec(ServicePrincipalCredentials)
Expand Down

0 comments on commit ee3bf75

Please sign in to comment.