Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

### Other Changes

- When `AZURE_TOKEN_CREDENTIALS` is set to `ManagedIdentityCredential`, `DefaultAzureCredential` now skips the IMDS endpoint probe request and directly attempts token acquisition with full retry logic, matching the behavior of using `ManagedIdentityCredential` standalone. ([#43080](https://github.com/Azure/azure-sdk-for-python/pull/43080))
- Improved error messages from `ManagedIdentityCredential` to include the full error response from managed identity endpoints for better troubleshooting. ([#43231](https://github.com/Azure/azure-sdk-for-python/pull/43231))

## 1.25.0 (2025-09-11)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement

process_timeout = kwargs.pop("process_timeout", 10)
require_envvar = kwargs.pop("require_envvar", False)
if require_envvar and not os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS):
token_credentials_env = os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS, "").strip().lower()
if require_envvar and not token_credentials_env:
raise ValueError(
"AZURE_TOKEN_CREDENTIALS environment variable is required but is not set or is empty. "
"Set it to 'dev', 'prod', or a specific credential name."
Expand Down Expand Up @@ -274,6 +275,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
ManagedIdentityCredential(
client_id=managed_identity_client_id,
_exclude_workload_identity_credential=exclude_workload_identity_credential,
_enable_imds_probe=token_credentials_env != "managedidentitycredential",
**kwargs,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def _check_forbidden_response(ex: HttpResponseError) -> None:

class ImdsCredential(MsalManagedIdentityClient):
def __init__(self, **kwargs: Any) -> None:
# If set to True/False, _enable_imds_probe forces whether or not the credential
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
# the credential probes only if it's part of a ChainedTokenCredential chain.
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
super().__init__(retry_policy_class=ImdsRetryPolicy, **dict(PIPELINE_SETTINGS, **kwargs))
self._config = kwargs

Expand All @@ -102,9 +106,9 @@ def close(self) -> None:

def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:

if within_credential_chain.get() and not self._endpoint_available:
# If within a chain (e.g. DefaultAzureCredential), we do a quick check to see if the IMDS endpoint
# is available to avoid hanging for a long time if the endpoint isn't available.
do_probe = self._enable_imds_probe if self._enable_imds_probe is not None else within_credential_chain.get()
if do_probe and not self._endpoint_available:
# Probe to see if the IMDS endpoint is available to avoid hanging for a long time if it's not.
try:
client = ManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **self._config))
client.request_token(*scopes, connection_timeout=1, retry_total=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
user_identity_info = validate_identity_config(client_id, identity_config)
self._credential: Optional[SupportsTokenInfo] = None
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
managed_identity_type = None

if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
Expand Down Expand Up @@ -136,7 +137,12 @@ def __init__(
managed_identity_type = "IMDS"
from .imds import ImdsCredential

self._credential = ImdsCredential(client_id=client_id, identity_config=identity_config, **kwargs)
self._credential = ImdsCredential(
client_id=client_id,
identity_config=identity_config,
_enable_imds_probe=self._enable_imds_probe,
**kwargs,
)

if managed_identity_type:
log_msg = f"{self.__class__.__name__} will use {managed_identity_type}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement

process_timeout = kwargs.pop("process_timeout", 10)
require_envvar = kwargs.pop("require_envvar", False)
if require_envvar and not os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS):
token_credentials_env = os.environ.get(EnvironmentVariables.AZURE_TOKEN_CREDENTIALS, "").strip().lower()
if require_envvar and not token_credentials_env:
raise ValueError(
"AZURE_TOKEN_CREDENTIALS environment variable is required but is not set or is empty. "
"Set it to 'dev', 'prod', or a specific credential name."
Expand Down Expand Up @@ -235,6 +236,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=too-many-statement
ManagedIdentityCredential(
client_id=managed_identity_client_id,
_exclude_workload_identity_credential=exclude_workload_identity_credential,
_enable_imds_probe=token_credentials_env != "managedidentitycredential",
**kwargs,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class ImdsCredential(AsyncContextManager, GetTokenMixin):
def __init__(self, **kwargs: Any) -> None:
super().__init__()

# If set to True/False, _enable_imds_probe forces whether or not the credential
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
# the credential probes only if it's part of a ChainedTokenCredential chain.
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
kwargs["retry_policy_class"] = AsyncImdsRetryPolicy
self._client = AsyncManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **kwargs))
if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ:
Expand All @@ -65,9 +69,9 @@ async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional

async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:

if within_credential_chain.get() and not self._endpoint_available:
# If within a chain (e.g. DefaultAzureCredential), we do a quick check to see if the IMDS endpoint
# is available to avoid hanging for a long time if the endpoint isn't available.
do_probe = self._enable_imds_probe if self._enable_imds_probe is not None else within_credential_chain.get()
if do_probe and not self._endpoint_available:
# Probe to see if the IMDS endpoint is available to avoid hanging for a long time if it's not.
try:
await self._client.request_token(*scopes, connection_timeout=1, retry_total=0)
self._endpoint_available = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
user_identity_info = validate_identity_config(client_id, identity_config)
self._credential: Optional[AsyncSupportsTokenInfo] = None
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
managed_identity_type = None
if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
if os.environ.get(EnvironmentVariables.IDENTITY_HEADER):
Expand Down Expand Up @@ -108,7 +109,12 @@ def __init__(
managed_identity_type = "IMDS"
from .imds import ImdsCredential

self._credential = ImdsCredential(client_id=client_id, identity_config=identity_config, **kwargs)
self._credential = ImdsCredential(
client_id=client_id,
identity_config=identity_config,
_enable_imds_probe=self._enable_imds_probe,
**kwargs,
)

if managed_identity_type:
log_msg = f"{self.__class__.__name__} will use {managed_identity_type}"
Expand Down
6 changes: 5 additions & 1 deletion sdk/identity/azure-identity/tests/test_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,11 @@ def test_default_credential_shared_cache_use(mock_credential):
def test_managed_identity_client_id():
"""the credential should accept a user-assigned managed identity's client ID by kwarg or environment variable"""

expected_args = {"client_id": "the-client", "_exclude_workload_identity_credential": False}
expected_args = {
"client_id": "the-client",
"_exclude_workload_identity_credential": False,
"_enable_imds_probe": True,
}

with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential:
DefaultAzureCredential(managed_identity_client_id=expected_args["client_id"])
Expand Down
6 changes: 5 additions & 1 deletion sdk/identity/azure-identity/tests/test_default_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,11 @@ async def test_default_credential_shared_cache_use():
def test_managed_identity_client_id():
"""the credential should accept a user-assigned managed identity's client ID by kwarg or environment variable"""

expected_args = {"client_id": "the-client", "_exclude_workload_identity_credential": False}
expected_args = {
"client_id": "the-client",
"_exclude_workload_identity_credential": False,
"_enable_imds_probe": True,
}

with patch(DefaultAzureCredential.__module__ + ".ManagedIdentityCredential") as mock_credential:
DefaultAzureCredential(managed_identity_client_id=expected_args["client_id"])
Expand Down
7 changes: 2 additions & 5 deletions sdk/identity/azure-identity/tests/test_imds_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
IMDS_AUTHORITY,
PIPELINE_SETTINGS,
)
from azure.identity._internal.utils import within_credential_chain
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.policies import RetryPolicy
from azure.core.pipeline.transport import HttpRequest, HttpResponse
Expand Down Expand Up @@ -109,7 +108,7 @@ def test_user_assigned_tenant_id(self, recorded_test, get_token_method):
assert isinstance(token.expires_on, int)

@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_managed_identity_aci_probe(self, get_token_method):
def test_enable_imds_probe(self, get_token_method):
access_token = "****"
expires_on = 42
expected_token = access_token
Expand Down Expand Up @@ -140,11 +139,9 @@ def test_managed_identity_aci_probe(self, get_token_method):
),
],
)
within_credential_chain.set(True)
credential = ImdsCredential(transport=transport)
credential = ImdsCredential(transport=transport, _enable_imds_probe=True)
token = getattr(credential, get_token_method)(scope)
assert token.token == expected_token
within_credential_chain.set(False)

def test_imds_credential_uses_custom_retry_policy(self):
credential = ImdsCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.policies import AsyncRetryPolicy
from azure.core.pipeline.transport import HttpRequest, HttpResponse
from azure.identity._internal.utils import within_credential_chain
import pytest

from helpers import mock_response, Request, GET_TOKEN_METHODS
Expand Down Expand Up @@ -316,7 +315,7 @@ async def test_user_assigned_tenant_id(self, recorded_test, get_token_method):

@pytest.mark.asyncio
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_managed_identity_aci_probe(self, get_token_method):
async def test_enable_imds_probe(self, get_token_method):
access_token = "****"
expires_on = 42
expected_token = access_token
Expand Down Expand Up @@ -346,11 +345,9 @@ async def test_managed_identity_aci_probe(self, get_token_method):
),
],
)
within_credential_chain.set(True)
credential = ImdsCredential(transport=transport)
credential = ImdsCredential(transport=transport, _enable_imds_probe=True)
token = await getattr(credential, get_token_method)(scope)
assert token.token == expected_token
within_credential_chain.set(False)

async def test_imds_credential_uses_custom_retry_policy(self):
credential = ImdsCredential()
Expand Down
76 changes: 76 additions & 0 deletions sdk/identity/azure-identity/tests/test_managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,82 @@ def test_imds_tenant_id(get_token_method):
assert token.token == expected_token


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_enable_imds_probe(get_token_method):
access_token = "****"
expires_on = 42
expected_token = access_token
scope = "scope"
transport = validating_transport(
requests=[
Request(base_url=IMDS_AUTHORITY + IMDS_TOKEN_PATH),
Request(
base_url=IMDS_AUTHORITY + IMDS_TOKEN_PATH,
method="GET",
required_headers={"Metadata": "true"},
required_params={"resource": scope},
),
],
responses=[
# probe receives error response
mock_response(status_code=400),
mock_response(
json_payload={
"access_token": access_token,
"expires_in": 42,
"expires_on": expires_on,
"ext_expires_in": 42,
"not_before": int(time.time()),
"resource": scope,
"token_type": "Bearer",
}
),
],
)
credential = ManagedIdentityCredential(transport=transport, _enable_imds_probe=True)
token = getattr(credential, get_token_method)(scope)
assert token.token == expected_token


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_imds_probe_enabled_in_chain(get_token_method):
access_token = "****"
expires_on = 42
expected_token = access_token
scope = "scope"
transport = validating_transport(
requests=[
Request(base_url=IMDS_AUTHORITY + IMDS_TOKEN_PATH),
Request(
base_url=IMDS_AUTHORITY + IMDS_TOKEN_PATH,
method="GET",
required_headers={"Metadata": "true"},
required_params={"resource": scope},
),
],
responses=[
# probe receives error response
mock_response(status_code=400),
mock_response(
json_payload={
"access_token": access_token,
"expires_in": 42,
"expires_on": expires_on,
"ext_expires_in": 42,
"not_before": int(time.time()),
"resource": scope,
"token_type": "Bearer",
}
),
],
)
credential = ManagedIdentityCredential(transport=transport)
within_credential_chain.set(True)
token = getattr(credential, get_token_method)(scope)
assert token.token == expected_token
within_credential_chain.set(False)


@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
def test_imds_text_response(get_token_method):
within_credential_chain.set(True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,82 @@ async def test_imds_user_assigned_identity(get_token_method):
assert token.token == expected_token


@pytest.mark.asyncio
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_enable_imds_probe(get_token_method):
access_token = "****"
expires_on = 42
expected_token = access_token
scope = "scope"
transport = async_validating_transport(
requests=[
Request(base_url=IMDS_AUTHORITY + IMDS_TOKEN_PATH),
Request(
base_url=IMDS_AUTHORITY + IMDS_TOKEN_PATH,
method="GET",
required_headers={"Metadata": "true"},
required_params={"resource": scope},
),
],
responses=[
mock_response(status_code=400),
mock_response(
json_payload={
"access_token": access_token,
"expires_in": 42,
"expires_on": expires_on,
"ext_expires_in": 42,
"not_before": int(time.time()),
"resource": scope,
"token_type": "Bearer",
}
),
],
)
credential = ManagedIdentityCredential(transport=transport, _enable_imds_probe=True)
token = await getattr(credential, get_token_method)(scope)
assert token.token == expected_token


@pytest.mark.asyncio
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_imds_probe_enabled_in_chain(get_token_method):
access_token = "****"
expires_on = 42
expected_token = access_token
scope = "scope"
transport = async_validating_transport(
requests=[
Request(base_url=IMDS_AUTHORITY + IMDS_TOKEN_PATH),
Request(
base_url=IMDS_AUTHORITY + IMDS_TOKEN_PATH,
method="GET",
required_headers={"Metadata": "true"},
required_params={"resource": scope},
),
],
responses=[
mock_response(status_code=400),
mock_response(
json_payload={
"access_token": access_token,
"expires_in": 42,
"expires_on": expires_on,
"ext_expires_in": 42,
"not_before": int(time.time()),
"resource": scope,
"token_type": "Bearer",
}
),
],
)
within_credential_chain.set(True)
credential = ManagedIdentityCredential(transport=transport)
token = await getattr(credential, get_token_method)(scope)
assert token.token == expected_token
within_credential_chain.set(False)


@pytest.mark.asyncio
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
async def test_imds_text_response(get_token_method):
Expand Down
Loading