Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Credentials accept tenant_id argument to get_token #19602

Merged
merged 19 commits into from
Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
AadClient allows multitenant auth
  • Loading branch information
chlowell committed Jul 6, 2021
commit e5f2502804a5ea13bf6b491949e1e3fff526de4a
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,29 @@ class AadClient(AadClientBase):
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
# type: (Iterable[str], str, str, Optional[str], **Any) -> AccessToken
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs
)
now = int(time.time())
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken
request = self._get_client_certificate_request(scopes, certificate)
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
now = int(time.time())
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_client_secret_request(scopes, secret)
request = self._get_client_secret_request(scopes, secret, **kwargs)
now = int(time.time())
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token)
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
now = int(time.time())
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from azure.core.exceptions import ClientAuthenticationError
from . import get_default_authority, normalize_authority
from .._constants import DEFAULT_TOKEN_REFRESH_RETRY_DELAY, DEFAULT_REFRESH_OFFSET
from .._internal import resolve_tenant

try:
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -44,20 +45,30 @@
class AadClientBase(ABC):
_POST = ["POST"]

def __init__(self, tenant_id, client_id, authority=None, cache=None, **kwargs):
# type: (str, str, Optional[str], Optional[TokenCache], **Any) -> None
authority = normalize_authority(authority) if authority else get_default_authority()
self._token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token"))
def __init__(
self, tenant_id, client_id, authority=None, cache=None, allow_multitenant_authentication=False, **kwargs
):
# type: (str, str, Optional[str], Optional[TokenCache], bool, **Any) -> None
self._authority = normalize_authority(authority) if authority else get_default_authority()

self._tenant_id = tenant_id
self._allow_multitenant = allow_multitenant_authentication

self._cache = cache or TokenCache()
self._client_id = client_id
self._pipeline = self._build_pipeline(**kwargs)
self._token_refresh_retry_delay = DEFAULT_TOKEN_REFRESH_RETRY_DELAY
self._token_refresh_offset = DEFAULT_REFRESH_OFFSET
self._last_refresh_time = 0

def get_cached_access_token(self, scopes, query=None):
# type: (Iterable[str], Optional[dict]) -> Optional[AccessToken]
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=list(scopes), query=query)
def get_cached_access_token(self, scopes, **kwargs):
# type: (Iterable[str], **Any) -> Optional[AccessToken]
tenant = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs)
tokens = self._cache.find(
TokenCache.CredentialType.ACCESS_TOKEN,
target=list(scopes),
query={"client_id": self._client_id, "realm": tenant},
)
for token in tokens:
expires_on = int(token["expires_on"])
if expires_on > int(time.time()):
Expand Down Expand Up @@ -91,7 +102,7 @@ def _build_pipeline(self, config=None, policies=None, transport=None, **kwargs):

def _process_response(self, response, request_time):
# type: (PipelineResponse, int) -> AccessToken
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time
self._last_refresh_time = request_time # no matter succeed or not, update the last refresh time

content = ContentDecodePolicy.deserialize_from_http_generics(response.http_response)

Expand Down Expand Up @@ -133,17 +144,18 @@ def _process_response(self, response, request_time):
# caching is the final step because 'add' mutates 'content'
self._cache.add(
event={
"client_id": self._client_id,
"response": content,
"scope": response.http_request.body["scope"].split(),
"client_id": self._client_id,
"token_endpoint": response.http_request.url,
},
now=request_time,
)

return token

def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None):
# type: (Iterable[str], str, str, Optional[str]) -> HttpRequest
def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
# type: (Iterable[str], str, str, Optional[str], **Any) -> HttpRequest
data = {
"client_id": self._client_id,
"code": code,
Expand All @@ -154,14 +166,13 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None)
if client_secret:
data["client_secret"] = client_secret

request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
request = self._post(data, **kwargs)
return request

def _get_client_certificate_request(self, scopes, certificate):
# type: (Iterable[str], AadClientCertificate) -> HttpRequest
assertion = self._get_jwt_assertion(certificate)
def _get_client_certificate_request(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
audience = self._get_token_url(**kwargs)
assertion = self._get_jwt_assertion(certificate, audience)
data = {
"client_assertion": assertion,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
Expand All @@ -170,26 +181,22 @@ def _get_client_certificate_request(self, scopes, certificate):
"scope": " ".join(scopes),
}

request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
request = self._post(data, **kwargs)
return request

def _get_client_secret_request(self, scopes, secret):
# type: (Iterable[str], str) -> HttpRequest
def _get_client_secret_request(self, scopes, secret, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"client_id": self._client_id,
"client_secret": secret,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
}
request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
request = self._post(data, **kwargs)
return request

def _get_jwt_assertion(self, certificate):
# type: (AadClientCertificate) -> str
def _get_jwt_assertion(self, certificate, audience):
# type: (AadClientCertificate, str) -> str
now = int(time.time())
header = six.ensure_binary(
json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8"
Expand All @@ -198,7 +205,7 @@ def _get_jwt_assertion(self, certificate):
json.dumps(
{
"jti": str(uuid4()),
"aud": self._token_endpoint,
"aud": audience,
"iss": self._client_id,
"sub": self._client_id,
"nbf": now,
Expand All @@ -213,20 +220,28 @@ def _get_jwt_assertion(self, certificate):

return jwt_bytes.decode("utf-8")

def _get_refresh_token_request(self, scopes, refresh_token):
# type: (Iterable[str], str) -> HttpRequest
def _get_refresh_token_request(self, scopes, refresh_token, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"scope": " ".join(scopes),
"client_id": self._client_id,
"client_info": 1, # request AAD include home_account_id in its response
}
request = HttpRequest(
"POST", self._token_endpoint, headers={"Content-Type": "application/x-www-form-urlencoded"}, data=data
)
request = self._post(data, **kwargs)
return request

def _get_token_url(self, **kwargs):
# type: (**Any) -> str
tenant = resolve_tenant(self._tenant_id, self._allow_multitenant, **kwargs)
return "/".join((self._authority, tenant, "oauth2/v2.0/token"))

def _post(self, data, **kwargs):
# type: (dict, **Any) -> HttpRequest
url = self._get_token_url(**kwargs)
return HttpRequest("POST", url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"})


def _scrub_secrets(response):
# type: (dict) -> None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,31 +52,31 @@ async def obtain_token_by_authorization_code(
**kwargs: "Any"
) -> "AccessToken":
request = self._get_auth_code_request(
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret
scopes=scopes, code=code, redirect_uri=redirect_uri, client_secret=client_secret, **kwargs
)
now = int(time.time())
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken
request = self._get_client_certificate_request(scopes, certificate)
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_client_secret(
self, scopes: "Iterable[str]", secret: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_client_secret_request(scopes, secret)
request = self._get_client_secret_request(scopes, secret, **kwargs)
now = int(time.time())
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_refresh_token(
self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_refresh_token_request(scopes, refresh_token)
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
now = int(time.time())
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)
Expand Down