Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ClientSecretCredential to use AadClient #11718

Merged
merged 3 commits into from
Jun 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from .._authn_client import AuthnClient
from .._base import ClientSecretCredentialBase
from .._internal import AadClient, ClientSecretCredentialBase

try:
from typing import TYPE_CHECKING
Expand All @@ -28,12 +27,7 @@ class ClientSecretCredential(ClientSecretCredentialBase):
defines authorities for other clouds.
"""

def __init__(self, tenant_id, client_id, client_secret, **kwargs):
# type: (str, str, str, **Any) -> None
super(ClientSecretCredential, self).__init__(tenant_id, client_id, client_secret, **kwargs)
self._client = AuthnClient(tenant=tenant_id, **kwargs)

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Request an access token for `scopes`.

Expand All @@ -48,8 +42,10 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_token(scopes)
token = self._client.get_cached_access_token(scopes)
if not token:
data = dict(self._form_data, scope=" ".join(scopes))
token = self._client.request_token(scopes, form_data=data)
token = self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
return AadClient(tenant_id, client_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_default_authority():
from .auth_code_redirect_handler import AuthCodeRedirectServer
from .aadclient_certificate import AadClientCertificate
from .certificate_credential_base import CertificateCredentialBase
from .client_secret_credential_base import ClientSecretCredentialBase
from .exception_wrapper import wrap_exceptions
from .msal_credentials import ConfidentialClientCredential, InteractiveCredential, PublicClientCredential
from .msal_transport_adapter import MsalTransportAdapter, MsalTransportResponse
Expand All @@ -60,6 +61,7 @@ def _scopes_to_resource(*scopes):
"AuthCodeRedirectServer",
"AadClientCertificate",
"CertificateCredentialBase",
"ClientSecretCredentialBase",
"ConfidentialClientCredential",
"get_default_authority",
"InteractiveCredential",
Expand Down
12 changes: 10 additions & 2 deletions sdk/identity/azure-identity/azure/identity/_internal/aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

class AadClient(AadClientBase):
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
# type: (str, str, Sequence[str], Optional[str], **Any) -> AccessToken
# type: (Sequence[str], str, str, Optional[str], **Any) -> AccessToken
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret
)
Expand All @@ -50,8 +50,16 @@ def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
# type: (Sequence[str], str, **Any) -> AccessToken
request = self._get_client_secret_request(scopes, secret)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (str, Sequence[str], **Any) -> AccessToken
# type: (Sequence[str], str, **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token)
now = int(time.time())
response = self._pipeline.run(request, stream=False, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_
def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
pass
Expand Down Expand Up @@ -131,6 +135,19 @@ def _get_client_certificate_request(self, scopes, certificate):
)
return request

def _get_client_secret_request(self, scopes, secret):
# type: (Sequence[str], str) -> HttpRequest
data = {
"client_id": self._client_id,
"client_secret": secret,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
}
request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
return request

def _get_jwt_assertion(self, certificate):
# type: (AadClientCertificate) -> str
now = int(time.time())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,33 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
from typing import TYPE_CHECKING

try:
ABC = abc.ABC
except AttributeError: # Python 2.7, abc exists, but not ABC
except AttributeError: # Python 2.7
ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()}) # type: ignore

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Optional, Union

# pylint:disable=unused-import,ungrouped-imports
from typing import Any

class ClientSecretCredentialBase(object):
"""Sans I/O base for client secret credentials"""

def __init__(self, tenant_id, client_id, secret, **kwargs): # pylint:disable=unused-argument
class ClientSecretCredentialBase(ABC):
def __init__(self, tenant_id, client_id, client_secret, **kwargs):
# type: (str, str, str, **Any) -> None
if not client_id:
raise ValueError("client_id should be the id of an Azure Active Directory application")
if not secret:
if not client_secret:
raise ValueError("secret should be an Azure Active Directory application's client secret")
if not tenant_id:
raise ValueError(
"tenant_id should be an Azure Active Directory tenant's id (also called its 'directory id')"
)
self._form_data = {"client_id": client_id, "client_secret": secret, "grant_type": "client_credentials"}
super(ClientSecretCredentialBase, self).__init__()

self._client = self._get_auth_client(tenant_id, client_id, **kwargs)
self._secret = client_secret

@abc.abstractmethod
def _get_auth_client(self, tenant_id, client_id, **kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from typing import TYPE_CHECKING

from .base import AsyncCredentialBase
from .._authn_client import AsyncAuthnClient
from ..._base import ClientSecretCredentialBase
from .._internal import AadClient
from ..._internal import ClientSecretCredentialBase

if TYPE_CHECKING:
from typing import Any
from azure.core.credentials import AccessToken


class ClientSecretCredential(ClientSecretCredentialBase, AsyncCredentialBase):
class ClientSecretCredential(AsyncCredentialBase, ClientSecretCredentialBase):
"""Authenticates as a service principal using a client ID and client secret.

:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
Expand All @@ -25,10 +25,6 @@ class ClientSecretCredential(ClientSecretCredentialBase, AsyncCredentialBase):
defines authorities for other clouds.
"""

def __init__(self, tenant_id: str, client_id: str, client_secret: str, **kwargs: "Any") -> None:
super(ClientSecretCredential, self).__init__(tenant_id, client_id, client_secret, **kwargs)
self._client = AsyncAuthnClient(tenant=tenant_id, **kwargs)

async def __aenter__(self):
await self._client.__aenter__()
return self
Expand All @@ -38,7 +34,7 @@ async def close(self):

await self._client.__aexit__()

async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # pylint:disable=unused-argument
async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
"""Asynchronously request an access token for `scopes`.

.. note:: This method is called by Azure SDK clients. It isn't intended for use in application code.
Expand All @@ -52,8 +48,10 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
if not scopes:
raise ValueError("'get_token' requires at least one scope")

token = self._client.get_cached_token(scopes)
token = self._client.get_cached_access_token(scopes)
if not token:
data = dict(self._form_data, scope=" ".join(scopes))
token = await self._client.request_token(scopes, form_data=data)
return token # type: ignore
token = await self._client.obtain_token_by_client_secret(scopes, self._secret, **kwargs)
return token

def _get_auth_client(self, tenant_id, client_id, **kwargs):
return AadClient(tenant_id, client_id, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

async def obtain_token_by_client_secret(
self, scopes: "Sequence[str]", secret: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_client_secret_request(scopes, secret)
now = int(time.time())
response = await self._pipeline.run(request, **kwargs)
content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)
return self._process_response(response=content, scopes=scopes, now=now)

async def obtain_token_by_refresh_token(
self, scopes: "Sequence[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
Expand Down
24 changes: 24 additions & 0 deletions sdk/identity/azure-identity/tests/test_aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,30 @@ def send(request, **_):
assert transport.send.call_count == 1


def test_client_secret():
tenant_id = "tenant-id"
client_id = "client-id"
scope = "scope"
secret = "refresh-token"
access_token = "***"

def send(request, **_):
assert request.data["client_id"] == client_id
assert request.data["client_secret"] == secret
assert request.data["grant_type"] == "client_credentials"
assert request.data["scope"] == scope

return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

transport = Mock(send=Mock(wraps=send))

client = AadClient(tenant_id, client_id, transport=transport)
token = client.obtain_token_by_client_secret(scopes=(scope,), secret=secret)

assert token.token == access_token
assert transport.send.call_count == 1


def test_refresh_token():
tenant_id = "tenant-id"
client_id = "client-id"
Expand Down
24 changes: 24 additions & 0 deletions sdk/identity/azure-identity/tests/test_aad_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,30 @@ async def send(request, **_):
assert transport.send.call_count == 1


async def test_client_secret():
tenant_id = "tenant-id"
client_id = "client-id"
scope = "scope"
secret = "refresh-token"
access_token = "***"

async def send(request, **_):
assert request.data["client_id"] == client_id
assert request.data["client_secret"] == secret
assert request.data["grant_type"] == "client_credentials"
assert request.data["scope"] == scope

return mock_response(json_payload={"access_token": access_token, "expires_in": 42})

transport = Mock(send=Mock(wraps=send))

client = AadClient(tenant_id, client_id, transport=transport)
token = await client.obtain_token_by_client_secret(scopes=(scope,), secret=secret)

assert token.token == access_token
assert transport.send.call_count == 1


async def test_refresh_token():
tenant_id = "tenant-id"
client_id = "client-id"
Expand Down