Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Add multi-tenant authentication policy #24019

Closed
wants to merge 16 commits into from
13 changes: 10 additions & 3 deletions sdk/core/azure-core/azure/core/pipeline/policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
# --------------------------------------------------------------------------

from ._base import HTTPPolicy, SansIOHTTPPolicy, RequestHistory
from ._authentication import BearerTokenCredentialPolicy, AzureKeyCredentialPolicy, AzureSasCredentialPolicy
from ._authentication import (
AzureKeyCredentialPolicy,
AzureSasCredentialPolicy,
BearerTokenCredentialPolicy,
BearerTokenChallengePolicy,
)
from ._custom_hook import CustomHookPolicy
from ._redirect import RedirectPolicy
from ._retry import RetryPolicy, RetryMode
Expand All @@ -43,9 +48,10 @@
__all__ = [
'HTTPPolicy',
'SansIOHTTPPolicy',
'BearerTokenCredentialPolicy',
'AzureKeyCredentialPolicy',
'AzureSasCredentialPolicy',
'BearerTokenCredentialPolicy',
'BearerTokenChallengePolicy',
'HeadersPolicy',
'UserAgentPolicy',
'NetworkTraceLoggingPolicy',
Expand All @@ -65,12 +71,13 @@

try:
from ._base_async import AsyncHTTPPolicy
from ._authentication_async import AsyncBearerTokenCredentialPolicy
from ._authentication_async import AsyncBearerTokenCredentialPolicy, AsyncBearerTokenChallengePolicy
from ._redirect_async import AsyncRedirectPolicy
from ._retry_async import AsyncRetryPolicy
__all__.extend([
'AsyncHTTPPolicy',
'AsyncBearerTokenCredentialPolicy',
'AsyncBearerTokenChallengePolicy',
'AsyncRedirectPolicy',
'AsyncRetryPolicy'
])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
# license information.
# -------------------------------------------------------------------------
import time
import urllib.parse as parse
import six

from . import HTTPPolicy, SansIOHTTPPolicy
from . import HTTPPolicy, SansIOHTTPPolicy, _http_challenge_cache as ChallengeCache
from ...exceptions import ServiceRequestError

try:
Expand All @@ -21,6 +22,53 @@
from azure.core.pipeline import PipelineRequest, PipelineResponse


class HttpChallenge(object): # pylint:disable=too-few-public-methods
"""Represents a parsed HTTP WWW-Authentication Bearer challenge from a server."""

def __init__(self, request_uri, challenge):
if not request_uri:
raise ValueError("URL cannot be empty")
if not challenge:
raise ValueError("Challenge cannot be empty")

uri = parse.urlparse(request_uri)
if not uri.netloc:
raise ValueError("request_uri must be an absolute URI")
self.source_authority = uri.netloc
self._parameters = {}

# Split the scheme ("Bearer") from the challenge parameters
trimmed_challenge = challenge.strip()
split_challenge = trimmed_challenge.split(" ", 1)
trimmed_challenge = split_challenge[1]

# Split trimmed challenge into name=value pairs; these pairs are expected to be split by either commas or spaces
# Values may be surrounded by quotes, which are stripped here
annatisch marked this conversation as resolved.
Show resolved Hide resolved
separator = "," if "," in trimmed_challenge else " "
for item in trimmed_challenge.split(separator):
# Process 'name=value' pairs
comps = item.split("=")
if len(comps) == 2:
key = comps[0].strip(' "')
value = comps[1].strip(' "')
if key:
self._parameters[key] = value

# Challenge must specify authorization or authorization_uri
if not self._parameters or (
"authorization" not in self._parameters and "authorization_uri" not in self._parameters
):
raise ValueError("Invalid challenge parameters. `authorization` or `authorization_uri` must be present.")

authorization_uri = self._parameters.get("authorization") or self._parameters.get("authorization_uri") or ""
# the authorization server URI should look something like https://login.windows.net/tenant-id[/oauth2/authorize]
uri_path = parse.urlparse(authorization_uri).path.lstrip("/")
self.tenant_id = uri_path.split("/")[0] or None

self.scope = self._parameters.get("scope") or ""
self.resource = self._parameters.get("resource") or self._parameters.get("resource_id") or ""


# pylint:disable=too-few-public-methods
class _BearerTokenCredentialPolicyBase(object):
"""Base class for a Bearer Token Credential Policy.
Expand Down Expand Up @@ -172,6 +220,89 @@ def on_exception(self, request):
return


class BearerTokenChallengePolicy(BearerTokenCredentialPolicy):
"""Adds a bearer token Authorization header to requests, for the tenant provided in authentication challenges.

See https://docs.microsoft.com/azure/active-directory/develop/claims-challenge for documentation on AAD
authentication challenges.

:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword bool discover_tenant: Determines if tenant discovery should be enabled. Defaults to True.
:keyword bool discover_scopes: Determines if scopes from authentication challenges should be provided to token
requests, instead of the scopes given to the policy's constructor, if any are present. Defaults to True.
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

def __init__(
self,
credential: "TokenCredential",
*scopes: str,
discover_tenant: bool = True,
discover_scopes: bool = True,
**kwargs: "Any"
) -> None:
self._discover_tenant = discover_tenant
self._discover_scopes = discover_scopes
self.challenge_cache = ChallengeCache
super().__init__(credential, *scopes, **kwargs)

def on_request(self, request):
# type: (PipelineRequest) -> None
"""Called before the policy sends a request.

The base implementation authorizes the request with a bearer token.

:param ~azure.core.pipeline.PipelineRequest request: the request
"""
self._enforce_https(request)
challenge = self.challenge_cache.get_challenge_for_url(request.http_request.url)

if self._token is None or self._need_new_token:
self._authorize_request_with_challenge(request, challenge)
else:
self._update_headers(request.http_request.headers, self._token.token)

def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
"""Authorize request according to an authentication challenge

This method is called when the resource provider responds 401 with a WWW-Authenticate header.

:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
if not self._discover_tenant and not self._discover_scopes:
# We can't discover the tenant or use a different scope; the request will fail because it hasn't changed
return False

try:
challenge = HttpChallenge(request.http_request.url, response.http_response.headers.get("WWW-Authenticate"))
self.challenge_cache.set_challenge_for_url(request.http_request.url, challenge)
except ValueError:
return False

self._authorize_request_with_challenge(request, challenge)
return True

def _authorize_request_with_challenge(self, request, challenge):
if not challenge:
self.authorize_request(request, self._scopes)
return

# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
# if no scopes are included in the challenge, challenge.scope and challenge.resource will both be ''
scope = challenge.scope or challenge.resource + "/.default" if self._discover_scopes else self._scopes
if scope == "/.default":
scope = self._scopes

if self._discover_tenant:
self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
else:
self.authorize_request(request, scope)


class AzureKeyCredentialPolicy(SansIOHTTPPolicy):
"""Adds a key header for the provided credential.

Expand All @@ -180,6 +311,7 @@ class AzureKeyCredentialPolicy(SansIOHTTPPolicy):
:param str name: The name of the key header used for the credential.
:raises: ValueError or TypeError
"""

def __init__(self, credential, name, **kwargs): # pylint: disable=unused-argument
# type: (AzureKeyCredential, str, **Any) -> None
super(AzureKeyCredentialPolicy, self).__init__()
Expand All @@ -201,6 +333,7 @@ class AzureSasCredentialPolicy(SansIOHTTPPolicy):
:type credential: ~azure.core.credentials.AzureSasCredential
:raises: ValueError or TypeError
"""

def __init__(self, credential, **kwargs): # pylint: disable=unused-argument
# type: (AzureSasCredential, **Any) -> None
super(AzureSasCredentialPolicy, self).__init__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from typing import TYPE_CHECKING

from azure.core.pipeline.policies import AsyncHTTPPolicy
from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase
from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase, HttpChallenge

from . import _http_challenge_cache as ChallengeCache
from .._tools_async import await_result

if TYPE_CHECKING:
Expand Down Expand Up @@ -128,3 +129,86 @@ def on_exception(self, request: "PipelineRequest") -> None:

def _need_new_token(self) -> bool:
return not self._token or self._token.expires_on - time.time() < 300


class AsyncBearerTokenChallengePolicy(AsyncBearerTokenCredentialPolicy):
"""Adds a bearer token Authorization header to requests, for the tenant provided in authentication challenges.

See https://docs.microsoft.com/azure/active-directory/develop/claims-challenge for documentation on AAD
authentication challenges.

:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword bool discover_tenant: Determines if tenant discovery should be enabled. Defaults to True.
:keyword bool discover_scopes: Determines if scopes from authentication challenges should be provided to token
requests, instead of the scopes given to the policy's constructor, if any are present. Defaults to True.
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

def __init__(
self,
credential: "AsyncTokenCredential",
*scopes: str,
discover_tenant: bool = True,
discover_scopes: bool = True,
**kwargs: "Any"
) -> None:
self._discover_tenant = discover_tenant
self._discover_scopes = discover_scopes
self.challenge_cache = ChallengeCache
super().__init__(credential, *scopes, **kwargs)

async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
"""Adds a bearer token Authorization header to request and sends request to next policy.

:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
_BearerTokenCredentialPolicyBase._enforce_https(request) # pylint:disable=protected-access
challenge = self.challenge_cache.get_challenge_for_url(request.http_request.url)

if self._token is None or self._need_new_token():
# We don't acquire the lock here because _authorize_request_with_challenge does via authorize_request
await self._authorize_request_with_challenge(request, challenge)
else:
request.http_request.headers["Authorization"] = "Bearer " + self._token.token

async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
"""Authorize request according to an authentication challenge

This method is called when the resource provider responds 401 with a WWW-Authenticate header.

:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
if not self._discover_tenant and not self._discover_scopes:
# We can't discover the tenant or use a different scope; the request will fail because it hasn't changed
return False

try:
challenge = HttpChallenge(request.http_request.url, response.http_response.headers.get("WWW-Authenticate"))
self.challenge_cache.set_challenge_for_url(request.http_request.url, challenge)
except ValueError:
return False

await self._authorize_request_with_challenge(request, challenge)
return True

async def _authorize_request_with_challenge(self, request, challenge):
if not challenge:
await self.authorize_request(request, self._scopes)
return

# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
# if no scopes are included in the challenge, challenge.scope and challenge.resource will both be ''
scope = challenge.scope or challenge.resource + "/.default" if self._discover_scopes else self._scopes
if scope == "/.default":
scope = self._scopes

if self._discover_tenant:
await self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
else:
await self.authorize_request(request, scope)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import threading

import urllib.parse as parse

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint: disable=unused-import
from typing import Dict
from ._authentication import HttpChallenge


_cache = {} # type: Dict[str, HttpChallenge]
_lock = threading.Lock()


def get_challenge_for_url(url):
"""Gets the challenge for the cached URL.

:param url: the URL the challenge is cached for.
:rtype: HttpBearerChallenge
"""
key = _get_cache_key(url)

with _lock:
return _cache.get(key)


def _get_cache_key(url):
"""Use the URL's netloc as cache key except when the URL specifies the default port for its scheme. In that case
use the netloc without the port. That is to say, https://foo.bar and https://foo.bar:443 are considered equivalent.

This equivalency prevents an unnecessary challenge when using Key Vault's paging API. The Key Vault client doesn't
specify ports, but Key Vault's next page links do, so a redundant challenge would otherwise be executed when the
client requests the next page.
"""
parsed = parse.urlparse(url)
if parsed.scheme == "https" and parsed.port == 443:
return parsed.netloc[:-4]
return parsed.netloc


def remove_challenge_for_url(url):
"""Removes the cached challenge for the specified URL.

:param url: the URL for which to remove the cached challenge
"""
url = parse.urlparse(url)

with _lock:
del _cache[url.netloc]


def set_challenge_for_url(url, challenge):
"""Caches the challenge for the specified URL.

:param url: the URL for which to cache the challenge
:param challenge: the challenge to cache
"""
src_url = parse.urlparse(url)
if src_url.netloc != challenge.source_authority:
raise ValueError("Source URL and Challenge URL do not match")

with _lock:
_cache[src_url.netloc] = challenge


def clear():
"""Clears the cache."""
with _lock:
_cache.clear()
Loading