diff --git a/sdk/identity/azure-identity/tests/helpers.py b/sdk/identity/azure-identity/tests/helpers.py index 2805c54f3135..b1be22602c72 100644 --- a/sdk/identity/azure-identity/tests/helpers.py +++ b/sdk/identity/azure-identity/tests/helpers.py @@ -153,8 +153,14 @@ def mock_response(status_code=200, headers=None, json_payload=None): def get_discovery_response(endpoint="https://a/b"): + """Get a mock response containing the values MSAL requires from tenant and instance discovery. + + The response is incomplete and its values aren't necessarily valid, particularly for instance discovery, but it's + sufficient. MSAL will send token requests to "{endpoint}/oauth2/v2.0/token_endpoint" after receiving a tenant + discovery response created by this method. + """ aad_metadata_endpoint_names = ("authorization_endpoint", "token_endpoint", "tenant_discovery_endpoint") - payload = {name: endpoint for name in aad_metadata_endpoint_names} + payload = {name: endpoint + "/oauth2/v2.0/" + name for name in aad_metadata_endpoint_names} payload["metadata"] = "" return mock_response(json_payload=payload) diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 4a313608a3b1..9f798d50c4fe 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -56,7 +56,11 @@ def test_exceptions_do_not_expose_secrets(): fns = [ functools.partial(client.obtain_token_by_authorization_code, "code", "uri", "scope"), - functools.partial(client.obtain_token_by_refresh_token, "refresh token", ("scope"),), + functools.partial( + client.obtain_token_by_refresh_token, + "refresh token", + ("scope"), + ), ] def assert_secrets_not_exposed(): @@ -233,3 +237,76 @@ def test_retries_token_requests(): client.obtain_token_by_refresh_token("", "") assert transport.send.call_count > 1 transport.send.reset_mock() + + +def test_shared_cache(): + """The client should return only tokens associated with its own client_id""" + + client_id_a = "client-id-a" + client_id_b = "client-id-b" + scope = "scope" + expected_token = "***" + tenant_id = "tenant" + authority = "https://localhost/" + tenant_id + + cache = TokenCache() + cache.add( + { + "response": build_aad_response(access_token=expected_token), + "client_id": client_id_a, + "scope": [scope], + "token_endpoint": "/".join((authority, tenant_id, "oauth2/v2.0/token")), + } + ) + + common_args = dict(authority=authority, cache=cache, tenant_id=tenant_id) + client_a = AadClient(client_id=client_id_a, **common_args) + client_b = AadClient(client_id=client_id_b, **common_args) + + # A has a cached token + token = client_a.get_cached_access_token([scope]) + assert token.token == expected_token + + # which B shouldn't return + assert client_b.get_cached_access_token([scope]) is None + + +def test_multitenant_cache(): + client_id = "client-id" + scope = "scope" + expected_token = "***" + tenant_a = "tenant-a" + tenant_b = "tenant-b" + tenant_c = "tenant-c" + authority = "https://localhost/" + tenant_a + + cache = TokenCache() + cache.add( + { + "response": build_aad_response(access_token=expected_token), + "client_id": client_id, + "scope": [scope], + "token_endpoint": "/".join((authority, tenant_a, "oauth2/v2.0/token")), + } + ) + + common_args = dict(authority=authority, cache=cache, client_id=client_id) + client_a = AadClient(tenant_id=tenant_a, **common_args) + client_b = AadClient(tenant_id=tenant_b, **common_args) + + # A has a cached token + token = client_a.get_cached_access_token([scope]) + assert token.token == expected_token + + # which B shouldn't return + assert client_b.get_cached_access_token([scope]) is None + + # but C allows multitenant auth and should therefore return the token from tenant_a when appropriate + client_c = AadClient(tenant_id=tenant_c, allow_multitenant_authentication=True, **common_args) + assert client_c.get_cached_access_token([scope]) is None + token = client_c.get_cached_access_token([scope], tenant_id=tenant_a) + assert token.token == expected_token + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + assert client_c.get_cached_access_token([scope], tenant_id=tenant_a) is None diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index 81d3fc2d827f..64212d573e97 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -5,13 +5,11 @@ import functools from unittest.mock import Mock, patch from urllib.parse import urlparse -import time from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError -from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY +from azure.identity._constants import EnvironmentVariables from azure.identity._internal import AadClientCertificate from azure.identity.aio._internal.aad_client import AadClient -from azure.core.credentials import AccessToken from msal import TokenCache import pytest @@ -241,3 +239,76 @@ async def test_retries_token_requests(): await client.obtain_token_by_refresh_token("", "") assert transport.send.call_count > 1 transport.send.reset_mock() + + +async def test_shared_cache(): + """The client should return only tokens associated with its own client_id""" + + client_id_a = "client-id-a" + client_id_b = "client-id-b" + scope = "scope" + expected_token = "***" + tenant_id = "tenant" + authority = "https://localhost/" + tenant_id + + cache = TokenCache() + cache.add( + { + "response": build_aad_response(access_token=expected_token), + "client_id": client_id_a, + "scope": [scope], + "token_endpoint": "/".join((authority, tenant_id, "oauth2/v2.0/token")), + } + ) + + common_args = dict(authority=authority, cache=cache, tenant_id=tenant_id) + client_a = AadClient(client_id=client_id_a, **common_args) + client_b = AadClient(client_id=client_id_b, **common_args) + + # A has a cached token + token = client_a.get_cached_access_token([scope]) + assert token.token == expected_token + + # which B shouldn't return + assert client_b.get_cached_access_token([scope]) is None + + +async def test_multitenant_cache(): + client_id = "client-id" + scope = "scope" + expected_token = "***" + tenant_a = "tenant-a" + tenant_b = "tenant-b" + tenant_c = "tenant-c" + authority = "https://localhost/" + tenant_a + + cache = TokenCache() + cache.add( + { + "response": build_aad_response(access_token=expected_token), + "client_id": client_id, + "scope": [scope], + "token_endpoint": "/".join((authority, tenant_a, "oauth2/v2.0/token")), + } + ) + + common_args = dict(authority=authority, cache=cache, client_id=client_id) + client_a = AadClient(tenant_id=tenant_a, **common_args) + client_b = AadClient(tenant_id=tenant_b, **common_args) + + # A has a cached token + token = client_a.get_cached_access_token([scope]) + assert token.token == expected_token + + # which B shouldn't return + assert client_b.get_cached_access_token([scope]) is None + + # but C allows multitenant auth and should therefore return the token from tenant_a when appropriate + client_c = AadClient(tenant_id=tenant_c, allow_multitenant_authentication=True, **common_args) + assert client_c.get_cached_access_token([scope]) is None + token = client_c.get_cached_access_token([scope], tenant_id=tenant_a) + assert token.token == expected_token + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + assert client_c.get_cached_access_token([scope], tenant_id=tenant_a) is None diff --git a/sdk/identity/azure-identity/tests/test_auth_code.py b/sdk/identity/azure-identity/tests/test_auth_code.py index 7b6ce76a75ed..a7baf083f0a2 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code.py +++ b/sdk/identity/azure-identity/tests/test_auth_code.py @@ -2,19 +2,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity import AuthorizationCodeCredential +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT import msal import pytest +from six.moves.urllib_parse import urlparse from helpers import build_aad_response, mock_response, Request, validating_transport try: - from unittest.mock import Mock + from unittest.mock import Mock, patch except ImportError: # python < 3.3 - from mock import Mock # type: ignore + from mock import Mock, patch # type: ignore def test_no_scopes(): @@ -114,3 +116,75 @@ def test_auth_code_credential(): token = credential.get_token(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 2 + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token, refresh_token="**")) + + credential = AuthorizationCodeCredential( + first_tenant, + "client-id", + "authcode", + "https://localhost", + allow_multitenant_authentication=True, + transport=Mock(send=send), + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_backcompat(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token, refresh_token="**")) + + credential = AuthorizationCodeCredential( + expected_tenant, "client-id", "authcode", "https://localhost", transport=Mock(send=send) + ) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_auth_code_async.py b/sdk/identity/azure-identity/tests/test_auth_code_async.py index 3b754b40a6f4..d7c6e81646f0 100644 --- a/sdk/identity/azure-identity/tests/test_auth_code_async.py +++ b/sdk/identity/azure-identity/tests/test_auth_code_async.py @@ -2,9 +2,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from unittest.mock import Mock +from unittest.mock import Mock, patch +from urllib.parse import urlparse +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy +from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT from azure.identity.aio import AuthorizationCodeCredential import msal @@ -137,3 +140,75 @@ async def test_auth_code_credential(): token = await credential.get_token(expected_scope) assert token.token == expected_access_token assert transport.send.call_count == 2 + + +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token, refresh_token="**")) + + credential = AuthorizationCodeCredential( + first_tenant, + "client-id", + "authcode", + "https://localhost", + allow_multitenant_authentication=True, + transport=Mock(send=send), + ) + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +async def test_multitenant_authentication_backcompat(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token, refresh_token="**")) + + credential = AuthorizationCodeCredential( + expected_tenant, "client-id", "authcode", "https://localhost", transport=Mock(send=send) + ) + + token = await credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = await credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index 666cec214860..38272fdb47ec 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -5,6 +5,7 @@ import json import os +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import CertificateCredential, RegionalAuthority, TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -343,3 +344,83 @@ def test_certificate_arguments(): CertificateCredential("tenant-id", "client-id", certificate_path="...", certificate_data="...") message = str(ex.value) assert "certificate_data" in message and "certificate_path" in message + + +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +def test_allow_multitenant_authentication(cert_path, cert_password): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant, "common"), 'unexpected tenant "{}"'.format(tenant) + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = CertificateCredential( + first_tenant, + "client-id", + cert_path, + password=cert_password, + allow_multitenant_authentication=True, + transport=Mock(send=send), + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +def test_multitenant_authentication_backcompat(cert_path, cert_password): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def send(request, **_): + parsed = urlparse(request.url) + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, expected_tenant)) + + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = CertificateCredential( + expected_tenant, "client-id", cert_path, password=cert_password, transport=Mock(send=send) + ) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict( + os.environ, {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py index 001308460ead..fbfaa562f157 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, patch from urllib.parse import urlparse +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -14,7 +15,7 @@ from msal import TokenCache import pytest -from helpers import build_aad_response, urlsafeb64_decode, mock_response, Request +from helpers import build_aad_response, mock_response, Request from helpers_async import async_validating_transport, AsyncMockTransport from test_certificate_credential import BOTH_CERTS, CERT_PATH, EC_CERT_PATH, validate_jwt @@ -265,3 +266,79 @@ def test_certificate_arguments(): CertificateCredential("tenant-id", "client-id", certificate_path="...", certificate_data="...") message = str(ex.value) assert "certificate_data" in message and "certificate_path" in message + + +@pytest.mark.asyncio +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +async def test_allow_multitenant_authentication(cert_path, cert_password): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = CertificateCredential( + first_tenant, + "client-id", + cert_path, + password=cert_password, + allow_multitenant_authentication=True, + transport=Mock(send=send), + ) + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +@pytest.mark.asyncio +@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) +async def test_multitenant_authentication_backcompat(cert_path, cert_password): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = CertificateCredential( + expected_tenant, "client-id", cert_path, password=cert_password, transport=Mock(send=send) + ) + + token = await credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = await credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_cli_credential.py b/sdk/identity/azure-identity/tests/test_cli_credential.py index eb0bb106125a..fa15a8eddc7d 100644 --- a/sdk/identity/azure-identity/tests/test_cli_credential.py +++ b/sdk/identity/azure-identity/tests/test_cli_credential.py @@ -4,9 +4,11 @@ # ------------------------------------ from datetime import datetime import json +import regex import sys from azure.identity import AzureCliCredential, CredentialUnavailableError +from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.azure_cli import CLI_NOT_FOUND, NOT_LOGGED_IN from azure.core.exceptions import ClientAuthenticationError @@ -148,3 +150,79 @@ def test_timeout(): with mock.patch(CHECK_OUTPUT, mock.Mock(side_effect=TimeoutExpired("", 42))): with pytest.raises(CredentialUnavailableError): AzureCliCredential().get_token("scope") + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + default_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def fake_check_output(command_line, **_): + match = regex.search("--tenant (.*)", command_line[-1]) + tenant = match[1] if match else default_tenant + assert tenant in (default_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + return json.dumps( + { + "expiresOn": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + "accessToken": first_token if tenant == default_tenant else second_token, + "subscription": "some-guid", + "tenant": tenant, + "tokenType": "Bearer", + } + ) + + credential = AzureCliCredential(allow_multitenant_authentication=True) + with mock.patch(CHECK_OUTPUT, fake_check_output): + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=default_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_backcompat(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def fake_check_output(command_line, **_): + match = regex.search("--tenant (.*)", command_line[-1]) + assert match is None or match[1] == expected_tenant + return json.dumps( + { + "expiresOn": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + "accessToken": expected_token, + "subscription": "some-guid", + "tenant": expected_token, + "tokenType": "Bearer", + } + ) + + credential = AzureCliCredential() + with mock.patch(CHECK_OUTPUT, fake_check_output): + token = credential.get_token("scope") + assert token.token == expected_token + + # specifying a tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with mock.patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert ( + token.token == expected_token + ), "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_cli_credential_async.py b/sdk/identity/azure-identity/tests/test_cli_credential_async.py index 69b1f8c1d41f..e291fdc7c267 100644 --- a/sdk/identity/azure-identity/tests/test_cli_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_cli_credential_async.py @@ -5,11 +5,13 @@ import asyncio from datetime import datetime import json +import regex import sys from unittest import mock from azure.identity import CredentialUnavailableError from azure.identity.aio import AzureCliCredential +from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.azure_cli import CLI_NOT_FOUND, NOT_LOGGED_IN from azure.core.exceptions import ClientAuthenticationError import pytest @@ -181,3 +183,81 @@ async def test_timeout(): assert proc.communicate.call_count == 1 assert proc.kill.call_count == 1 + + +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + default_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def fake_exec(*args, **_): + match = regex.search("--tenant (.*)", args[-1]) + tenant = match[1] if match else default_tenant + assert tenant in (default_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + output = json.dumps( + { + "expiresOn": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + "accessToken": first_token if tenant == default_tenant else second_token, + "subscription": "some-guid", + "tenant": tenant, + "tokenType": "Bearer", + } + ).encode() + return mock.Mock(communicate=mock.Mock(return_value=get_completed_future((output, b""))), returncode=0) + + credential = AzureCliCredential(allow_multitenant_authentication=True) + with mock.patch(SUBPROCESS_EXEC, fake_exec): + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=default_tenant) + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +async def test_multitenant_authentication_backcompat(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def fake_exec(*args, **_): + match = regex.search("--tenant (.*)", args[-1]) + assert match is None or match[1] == expected_tenant + output = json.dumps( + { + "expiresOn": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + "accessToken": expected_token, + "subscription": "some-guid", + "tenant": expected_token, + "tokenType": "Bearer", + } + ).encode() + return mock.Mock(communicate=mock.Mock(return_value=get_completed_future((output, b""))), returncode=0) + + credential = AzureCliCredential() + with mock.patch(SUBPROCESS_EXEC, fake_exec): + token = await credential.get_token("scope") + assert token.token == expected_token + + # specifying a tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with mock.patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert ( + token.token == expected_token + ), "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential.py b/sdk/identity/azure-identity/tests/test_client_secret_credential.py index ded3c9727e1d..dc4ff4af3d80 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import ClientSecretCredential, RegionalAuthority, TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -10,7 +11,7 @@ import pytest from six.moves.urllib_parse import urlparse -from helpers import build_aad_response, mock_response, msal_validating_transport, Request +from helpers import build_aad_response, get_discovery_response, mock_response, msal_validating_transport, Request try: from unittest.mock import Mock, patch @@ -208,3 +209,74 @@ def test_cache_multiple_clients(): assert transport_b.send.call_count == 3 assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2 + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant, "common"), 'unexpected tenant "{}"'.format(tenant) + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = ClientSecretCredential( + first_tenant, "client-id", "secret", allow_multitenant_authentication=True, transport=Mock(send=send) + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_backcompat(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def send(request, **_): + parsed = urlparse(request.url) + if "/oauth2/v2.0/token" not in parsed.path: + return get_discovery_response("https://{}/{}".format(parsed.netloc, expected_tenant)) + + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = ClientSecretCredential(expected_tenant, "client-id", "secret", transport=Mock(send=send)) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py index 96d4366828d6..8f1c85315b35 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import ContentDecodePolicy, SansIOHTTPPolicy from azure.identity import TokenCachePersistenceOptions from azure.identity._constants import EnvironmentVariables @@ -247,3 +248,70 @@ async def test_cache_multiple_clients(): assert transport_b.send.call_count == 1 assert len(cache.find(TokenCache.CredentialType.ACCESS_TOKEN)) == 2 + + +@pytest.mark.asyncio +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = ClientSecretCredential( + first_tenant, "client-id", "secret", allow_multitenant_authentication=True, transport=Mock(send=send) + ) + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +@pytest.mark.asyncio +async def test_multitenant_authentication_backcompat(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = ClientSecretCredential(expected_tenant, "client-id", "secret", transport=Mock(send=send)) + + token = await credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = await credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_interactive_credential.py b/sdk/identity/azure-identity/tests/test_interactive_credential.py index d3a0c8f52881..9537d4ae6639 100644 --- a/sdk/identity/azure-identity/tests/test_interactive_credential.py +++ b/sdk/identity/azure-identity/tests/test_interactive_credential.py @@ -10,15 +10,16 @@ CredentialUnavailableError, TokenCachePersistenceOptions, ) -from azure.identity._internal import InteractiveCredential +from azure.identity._internal import EnvironmentVariables, InteractiveCredential import pytest +from six.moves.urllib_parse import urlparse try: from unittest.mock import Mock, patch except ImportError: # python < 3.3 from mock import Mock, patch # type: ignore -from helpers import build_aad_response, build_id_token, id_token_claims +from helpers import build_aad_response, get_discovery_response, id_token_claims # fake object for tests which need to exercise request_token but don't care about its return value @@ -41,24 +42,14 @@ class MockCredential(InteractiveCredential): Default instances have an empty in-memory cache, and raise rather than send an HTTP request. """ - def __init__( - self, client_id="...", request_token=None, msal_app_factory=None, transport=None, **kwargs - ): - self._msal_app_factory = msal_app_factory + def __init__(self, client_id="...", request_token=None, transport=None, **kwargs): self._request_token_impl = request_token or Mock() transport = transport or Mock(send=Mock(side_effect=Exception("credential shouldn't send a request"))) - super(MockCredential, self).__init__( - client_id=client_id, transport=transport, **kwargs - ) + super(MockCredential, self).__init__(client_id=client_id, transport=transport, **kwargs) def _request_token(self, *scopes, **kwargs): return self._request_token_impl(*scopes, **kwargs) - def _get_app(self): - if self._msal_app_factory: - return self._create_app(self._msal_app_factory) - return super(MockCredential, self)._get_app() - def test_no_scopes(): """The credential should raise when get_token is called with no scopes""" @@ -79,14 +70,13 @@ def validate_app_parameters(authority, client_id, **_): assert client_id == record.client_id return Mock(get_accounts=Mock(return_value=[])) - app_factory = Mock(wraps=validate_app_parameters) - credential = MockCredential( - authentication_record=record, disable_automatic_authentication=True, msal_app_factory=app_factory, - ) + mock_client_application = Mock(wraps=validate_app_parameters) + credential = MockCredential(authentication_record=record, disable_automatic_authentication=True) with pytest.raises(AuthenticationRequiredError): - credential.get_token("scope") + with patch("msal.PublicClientApplication", mock_client_application): + credential.get_token("scope") - assert app_factory.call_count == 1, "credential didn't create an msal application" + assert mock_client_application.call_count == 1, "credential didn't create an msal application" def test_tenant_argument_overrides_record(): @@ -104,13 +94,11 @@ def validate_authority(authority, **_): return Mock(get_accounts=Mock(return_value=[])) credential = MockCredential( - authentication_record=record, - tenant_id=expected_tenant, - disable_automatic_authentication=True, - msal_app_factory=validate_authority, + authentication_record=record, tenant_id=expected_tenant, disable_automatic_authentication=True ) with pytest.raises(AuthenticationRequiredError): - credential.get_token("scope") + with patch("msal.PublicClientApplication", validate_authority): + credential.get_token("scope") def test_disable_automatic_authentication(): @@ -126,14 +114,14 @@ def test_disable_automatic_authentication(): credential = MockCredential( authentication_record=record, disable_automatic_authentication=True, - msal_app_factory=lambda *_, **__: msal_app, request_token=Mock(side_effect=Exception("credential shouldn't begin interactive authentication")), ) scope = "scope" expected_claims = "..." with pytest.raises(AuthenticationRequiredError) as ex: - credential.get_token(scope, claims=expected_claims) + with patch("msal.PublicClientApplication", lambda *_, **__: msal_app): + credential.get_token(scope, claims=expected_claims) # the exception should carry the requested scopes and claims, and any error message from AAD assert ex.value.scopes == (scope,) @@ -208,9 +196,10 @@ class CustomException(Exception): acquire_token_silent_with_error=Mock(side_effect=CustomException(expected_message)), get_accounts=Mock(return_value=[{"home_account_id": record.home_account_id}]), ) - credential = MockCredential(msal_app_factory=lambda *_, **__: msal_app, authentication_record=record) + credential = MockCredential(authentication_record=record) with pytest.raises(ClientAuthenticationError) as ex: - credential.get_token("scope") + with patch("msal.PublicClientApplication", lambda *_, **__: msal_app): + credential.get_token("scope") assert expected_message in ex.value.message assert msal_app.acquire_token_silent_with_error.call_count == 1, "credential didn't attempt silent auth" @@ -291,3 +280,96 @@ def _request_token(self, *_, **__): assert record.home_account_id == subject assert record.tenant_id == tenant assert record.username == username + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def request_token(*_, tenant_id=None, **__): + return build_aad_response( + access_token=second_token if tenant_id == second_tenant else first_token, + id_token_claims=id_token_claims( + aud="...", + iss="http://localhost/tenant", + sub="subject", + preferred_username="...", + tenant_id="...", + object_id="...", + ), + ) + + def send(request, **_): + assert "/oauth2/v2.0/token" not in request.url, 'mock "request_token" should prevent sending a token request' + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + credential = MockCredential( + tenant_id=first_tenant, + allow_multitenant_authentication=True, + request_token=request_token, + transport=Mock(send=send), + ) + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_backcompat(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def request_token(*_, **__): + return build_aad_response( + access_token=expected_token, + id_token_claims=id_token_claims( + aud="...", + iss="http://localhost/tenant", + sub="subject", + preferred_username="...", + tenant_id="...", + object_id="...", + ), + ) + + def send(request, **_): + assert "/oauth2/v2.0/token" not in request.url, 'mock "request_token" should prevent sending a token request' + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + return get_discovery_response("https://{}/{}".format(parsed.netloc, tenant)) + + credential = MockCredential(tenant_id=expected_tenant, transport=Mock(send=send), request_token=request_token) + + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index 0675d8da2547..6367daac324c 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -6,6 +6,7 @@ import time from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError from azure.identity import AzureAuthorityHosts, CredentialUnavailableError, VisualStudioCodeCredential from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.identity._constants import EnvironmentVariables @@ -265,3 +266,71 @@ def test_no_user_settings(): credential.get_token("scope") assert transport.send.call_count == 1 + + +def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = get_credential( + tenant_id=first_tenant, allow_multitenant_authentication=True, transport=mock.Mock(send=send) + ) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = credential.get_token("scope") + assert token.token == first_token + + token = credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = credential.get_token("scope") + assert token.token == first_token + + +def test_multitenant_authentication_backcompat(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = get_credential(tenant_id=expected_tenant, transport=mock.Mock(send=send)) + + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with mock.patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled" diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py index 6afeb58f655a..4d1da0e21bf7 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse from azure.core.credentials import AccessToken +from azure.core.exceptions import ClientAuthenticationError from azure.identity import AzureAuthorityHosts, CredentialUnavailableError from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT @@ -258,3 +259,74 @@ async def test_no_user_settings(): await credential.get_token("scope") assert transport.send.call_count == 1 + + + +@pytest.mark.asyncio +async def test_allow_multitenant_authentication(): + """When allow_multitenant_authentication is True, the credential should respect get_token(tenant_id=...)""" + + first_tenant = "first-tenant" + first_token = "***" + second_tenant = "second-tenant" + second_token = first_token * 2 + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + assert tenant in (first_tenant, second_tenant), 'unexpected tenant "{}"'.format(tenant) + token = first_token if tenant == first_tenant else second_token + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = get_credential( + tenant_id=first_tenant, allow_multitenant_authentication=True, transport=mock.Mock(send=send) + ) + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = await credential.get_token("scope") + assert token.token == first_token + + token = await credential.get_token("scope", tenant_id=first_tenant) + assert token.token == first_token + + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = await credential.get_token("scope", tenant_id=second_tenant) + assert token.token == second_token + + # should still default to the first tenant + token = await credential.get_token("scope") + assert token.token == first_token + + +@pytest.mark.asyncio +async def test_multitenant_authentication_backcompat(): + """get_token(tenant_id=...) should raise when allow_multitenant_authentication is False (the default)""" + + expected_tenant = "expected-tenant" + expected_token = "***" + + async def send(request, **_): + parsed = urlparse(request.url) + tenant = parsed.path.split("/")[1] + token = expected_token if tenant == expected_tenant else expected_token * 2 + return mock_response(json_payload=build_aad_response(access_token=token)) + + credential = get_credential(tenant_id=expected_tenant, transport=mock.Mock(send=send)) + + with mock.patch(GET_REFRESH_TOKEN, lambda _: "**"): + token = await credential.get_token("scope") + assert token.token == expected_token + + # explicitly specifying the configured tenant is okay + token = await credential.get_token("scope", tenant_id=expected_tenant) + assert token.token == expected_token + + # but any other tenant should get an error + with pytest.raises(ClientAuthenticationError, match="allow_multitenant_authentication"): + await credential.get_token("scope", tenant_id="un" + expected_tenant) + + # ...unless the compat switch is enabled + with mock.patch.dict( + "os.environ", {EnvironmentVariables.AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION: "true"}, clear=True + ): + token = await credential.get_token("scope", tenant_id="un" + expected_tenant) + assert token.token == expected_token, "credential should ignore tenant_id kwarg when the compat switch is enabled"