Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tables] Add multitenant challenge auth policy support #24278

Merged
merged 10 commits into from
May 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdk/tables/azure-data-tables/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Release History

## 12.4.0 (2022-05-10)

### Features Added
- Support for multitenant authentication ([#24278](https://github.com/Azure/azure-sdk-for-python/pull/24278))

## 12.3.0 (2022-03-10)

### Bugs Fixed
Expand Down
104 changes: 102 additions & 2 deletions sdk/tables/azure-data-tables/azure/data/tables/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from urlparse import urlparse # type: ignore

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy

try:
from azure.core.pipeline.transport import AsyncHttpTransport
Expand All @@ -33,7 +33,9 @@
)

if TYPE_CHECKING:
from azure.core.pipeline import PipelineRequest # pylint: disable=ungrouped-imports
from typing import Any
from azure.core.credentials import TokenCredential
from azure.core.pipeline import PipelineResponse, PipelineRequest # pylint: disable=ungrouped-imports


class AzureSigningError(ClientAuthenticationError):
Expand All @@ -44,6 +46,47 @@ class AzureSigningError(ClientAuthenticationError):
"""


class _HttpChallenge(object): # pylint:disable=too-few-public-methods
"""Represents a parsed HTTP WWW-Authentication Bearer challenge from a server."""

def __init__(self, challenge):
if not challenge:
raise ValueError("Challenge cannot be empty")

self._parameters = {}

# Split the scheme ("Bearer") from the challenge parameters
trimmed_challenge = challenge.strip()
split_challenge = trimmed_challenge.split(" ", 1)
trimmed_challenge = split_challenge[1]

# Split trimmed challenge into name=value pairs; these pairs are expected to be split by either commas or spaces
# Values may be surrounded by quotes, which are stripped here
separator = "," if "," in trimmed_challenge else " "
for item in trimmed_challenge.split(separator):
# Process 'name=value' pairs
comps = item.split("=")
if len(comps) == 2:
key = comps[0].strip(' "')
value = comps[1].strip(' "')
if key:
self._parameters[key] = value

# Challenge must specify authorization or authorization_uri
if not self._parameters or (
"authorization" not in self._parameters and "authorization_uri" not in self._parameters
):
raise ValueError("Invalid challenge parameters. `authorization` or `authorization_uri` must be present.")

authorization_uri = self._parameters.get("authorization") or self._parameters.get("authorization_uri") or ""
# the authorization server URI should look something like https://login.windows.net/tenant-id[/oauth2/authorize]
uri_path = urlparse(authorization_uri).path.lstrip("/")
self.tenant_id = uri_path.split("/")[0] or None

self.scope = self._parameters.get("scope") or ""
self.resource = self._parameters.get("resource") or self._parameters.get("resource_id") or ""


# pylint: disable=no-self-use
class SharedKeyCredentialPolicy(SansIOHTTPPolicy):
def __init__(self, credential, is_emulated=False):
Expand Down Expand Up @@ -128,3 +171,60 @@ def _get_canonicalized_resource_query(self, request):
if name == "comp":
return "?comp=" + value
return ""


class BearerTokenChallengePolicy(BearerTokenCredentialPolicy):
"""Adds a bearer token Authorization header to requests, for the tenant provided in authentication challenges.

See https://docs.microsoft.com/azure/active-directory/develop/claims-challenge for documentation on AAD
authentication challenges.

:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword bool discover_tenant: Determines if tenant discovery should be enabled. Defaults to True.
:keyword bool discover_scopes: Determines if scopes from authentication challenges should be provided to token
requests, instead of the scopes given to the policy's constructor, if any are present. Defaults to True.
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

def __init__(
self,
credential: "TokenCredential",
*scopes: str,
discover_tenant: bool = True,
discover_scopes: bool = True,
**kwargs: "Any"
) -> None:
self._discover_tenant = discover_tenant
self._discover_scopes = discover_scopes
super().__init__(credential, *scopes, **kwargs)

def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
"""Authorize request according to an authentication challenge

This method is called when the resource provider responds 401 with a WWW-Authenticate header.

:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
if not self._discover_tenant and not self._discover_scopes:
# We can't discover the tenant or use a different scope; the request will fail because it hasn't changed
return False

try:
challenge = _HttpChallenge(response.http_response.headers.get("WWW-Authenticate"))
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
# if no scopes are included in the challenge, challenge.scope and challenge.resource will both be ''
scope = challenge.scope or challenge.resource + "/.default" if self._discover_scopes else self._scopes
if scope == "/.default":
scope = self._scopes
except ValueError:
return False

if self._discover_tenant:
self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
else:
self.authorize_request(request, scope)
return True
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from azure.core.pipeline.policies import (
RedirectPolicy,
ContentDecodePolicy,
BearerTokenCredentialPolicy,
ProxyPolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
Expand All @@ -46,7 +45,7 @@
_validate_tablename_error
)
from ._models import LocationMode
from ._authentication import SharedKeyCredentialPolicy
from ._authentication import BearerTokenChallengePolicy, SharedKeyCredentialPolicy
from ._policies import (
CosmosPatchTransformPolicy,
StorageHeadersPolicy,
Expand All @@ -58,7 +57,7 @@
if TYPE_CHECKING:
from azure.core.credentials import TokenCredential

_SUPPORTED_API_VERSIONS = ["2019-02-02", "2019-07-07"]
_SUPPORTED_API_VERSIONS = ["2019-02-02", "2019-07-07", "2020-12-06"]


def get_api_version(kwargs, default):
Expand Down Expand Up @@ -249,7 +248,7 @@ def _configure_policies(self, **kwargs):
def _configure_credential(self, credential):
# type: (Any) -> None
if hasattr(credential, "get_token"):
self._credential_policy = BearerTokenCredentialPolicy( # type: ignore
self._credential_policy = BearerTokenChallengePolicy( # type: ignore
credential, STORAGE_OAUTH_SCOPE
annatisch marked this conversation as resolved.
Show resolved Hide resolved
)
elif isinstance(credential, SharedKeyCredentialPolicy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# license information.
# --------------------------------------------------------------------------

VERSION = "12.3.0"
VERSION = "12.4.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING

from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy

from .._authentication import _HttpChallenge

if TYPE_CHECKING:
from typing import Any
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineResponse, PipelineRequest # pylint: disable=ungrouped-imports


class AsyncBearerTokenChallengePolicy(AsyncBearerTokenCredentialPolicy):
"""Adds a bearer token Authorization header to requests, for the tenant provided in authentication challenges.

See https://docs.microsoft.com/azure/active-directory/develop/claims-challenge for documentation on AAD
authentication challenges.

:param credential: The credential.
:type credential: ~azure.core.TokenCredential
:param str scopes: Lets you specify the type of access needed.
:keyword bool discover_tenant: Determines if tenant discovery should be enabled. Defaults to True.
:keyword bool discover_scopes: Determines if scopes from authentication challenges should be provided to token
requests, instead of the scopes given to the policy's constructor, if any are present. Defaults to True.
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""

def __init__(
self,
credential: "AsyncTokenCredential",
*scopes: str,
discover_tenant: bool = True,
discover_scopes: bool = True,
**kwargs: "Any"
) -> None:
self._discover_tenant = discover_tenant
self._discover_scopes = discover_scopes
super().__init__(credential, *scopes, **kwargs)

async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
"""Authorize request according to an authentication challenge

This method is called when the resource provider responds 401 with a WWW-Authenticate header.

:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
"""
if not self._discover_tenant and not self._discover_scopes:
# We can't discover the tenant or use a different scope; the request will fail because it hasn't changed
return False

try:
challenge = _HttpChallenge(response.http_response.headers.get("WWW-Authenticate"))
# azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource
# if no scopes are included in the challenge, challenge.scope and challenge.resource will both be ''
scope = challenge.scope or challenge.resource + "/.default" if self._discover_scopes else self._scopes
if scope == "/.default":
scope = self._scopes
except ValueError:
return False

if self._discover_tenant:
await self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
else:
await self.authorize_request(request, scope)
return True
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential
from azure.core.pipeline.policies import (
ContentDecodePolicy,
AsyncBearerTokenCredentialPolicy,
AsyncRedirectPolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
Expand All @@ -26,6 +25,7 @@
HttpRequest,
)

from ._authentication_async import AsyncBearerTokenChallengePolicy
from .._generated.aio import AzureTable
from .._base_client import AccountHostsMixin, get_api_version, extract_batch_part_metadata
from .._authentication import SharedKeyCredentialPolicy
Expand Down Expand Up @@ -78,7 +78,7 @@ async def close(self) -> None:
def _configure_credential(self, credential):
# type: (Any) -> None
if hasattr(credential, "get_token"):
self._credential_policy = AsyncBearerTokenCredentialPolicy( # type: ignore
self._credential_policy = AsyncBearerTokenChallengePolicy( # type: ignore
credential, STORAGE_OAUTH_SCOPE
)
elif isinstance(credential, SharedKeyCredentialPolicy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AsyncFakeTokenCredential(object):
def __init__(self):
self.token = AccessToken("YOU SHALL NOT PASS", 0)

async def get_token(self, *args):
async def get_token(self, *args, **kwargs):
return self.token


Expand Down
2 changes: 1 addition & 1 deletion sdk/tables/azure-data-tables/tests/_shared/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class FakeTokenCredential(object):
def __init__(self):
self.token = AccessToken("YOU SHALL NOT PASS", 0)

def get_token(self, *args):
def get_token(self, *args, **kwargs):
return self.token


Expand Down
12 changes: 11 additions & 1 deletion sdk/tables/azure-data-tables/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import os

import pytest
from devtools_testutils import add_general_regex_sanitizer, test_proxy
from devtools_testutils import add_general_regex_sanitizer, add_body_key_sanitizer, test_proxy

# fixture needs to be visible from conftest

Expand All @@ -42,3 +44,11 @@ def add_sanitizers(test_proxy):
regex="batch[a-z]*_([0-9a-f]{8}\\b-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-\\b[0-9a-f]{12}\\b)",
group_for_replace="1",
)
# sanitizes tenant ID
tenant_id = os.environ.get("TABLES_TENANT_ID", "00000000-0000-0000-0000-000000000000")
add_general_regex_sanitizer(value="00000000-0000-0000-0000-000000000000", regex=tenant_id)
# sanitizes tenant ID used in test_challenge_auth(_async).py tests
challenge_tenant_id = os.environ.get("CHALLENGE_TABLES_TENANT_ID", "00000000-0000-0000-0000-000000000000")
add_general_regex_sanitizer(value="00000000-0000-0000-0000-000000000000", regex=challenge_tenant_id)
# sanitizes access tokens in response bodies
add_body_key_sanitizer(json_path="$..access_token", value="access_token")
Loading