Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Identity] Managed identity bug fix #36010

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
# Release History

## 1.17.0 (2024-06-11)
## 1.17.0b2 (2024-06-11)

### Features Added

- `OnBehalfOfCredential` now supports client assertion callbacks through the `client_assertion_func` keyword argument. This enables authenticating with client assertions such as federated credentials. ([#35812](https://github.com/Azure/azure-sdk-for-python/pull/35812))

### Bugs Fixed

- Managed identity bug fixes

## 1.16.1 (2024-06-11)

### Bugs Fixed

- Managed identity bug fixes

## 1.17.0b1 (2024-05-13)

### Features Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# ------------------------------------
import functools
import os
import sys
from typing import Any, Dict, Optional

from azure.core.exceptions import ClientAuthenticationError
Expand All @@ -24,7 +25,7 @@ def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]:
return ManagedIdentityClient(
_per_retry_policies=[ArcChallengeAuthPolicy()],
request_factory=functools.partial(_get_request, url),
**kwargs
**kwargs,
)
return None

Expand Down Expand Up @@ -70,6 +71,12 @@ def _get_secret_key(response: PipelineResponse) -> str:
raise ClientAuthenticationError(
message="Did not receive a correct value from WWW-Authenticate header: {}".format(header)
) from ex

try:
_validate_key_file(key_file)
except ValueError as ex:
raise ClientAuthenticationError(message="The key file path is invalid: {}".format(ex)) from ex

with open(key_file, "r", encoding="utf-8") as file:
try:
return file.read()
Expand All @@ -80,6 +87,53 @@ def _get_secret_key(response: PipelineResponse) -> str:
) from error


def _get_key_file_path() -> str:
"""Returns the expected path for the Azure Arc MSI key file based on the current platform.

Only Linux and Windows are supported.

:return: The expected path.
:rtype: str
:raises ValueError: If the current platform is not supported.
"""
if sys.platform.startswith("linux"):
return "/var/opt/azcmagent/tokens"
if sys.platform.startswith("win"):
program_data_path = os.environ.get("PROGRAMDATA")
if not program_data_path:
raise ValueError("PROGRAMDATA environment variable is not set or is empty.")
return os.path.join(f"{program_data_path}", "AzureConnectedMachineAgent", "Tokens")
raise ValueError(f"Azure Arc MSI is not supported on this platform {sys.platform}")


def _validate_key_file(file_path: str) -> None:
"""Validates that a given Azure Arc MSI file path is valid for use.

A valid file will:
1. Be in the expected path for the current platform.
2. Have a `.key` extension.
3. Be at most 4096 bytes in size.

:param str file_path: The path to the key file.
:raises ClientAuthenticationError: If the file path is invalid.
"""
if not file_path:
raise ValueError("The file path must not be empty.")

if not os.path.exists(file_path):
raise ValueError(f"The file path does not exist: {file_path}")

expected_directory = _get_key_file_path()
if not os.path.dirname(file_path) == expected_directory:
raise ValueError(f"Unexpected file path from HIMDS service: {file_path}")

if not file_path.endswith(".key"):
raise ValueError("The file path must have a '.key' extension.")

if os.path.getsize(file_path) > 4096:
raise ValueError("The file size must be less than or equal to 4096 bytes.")


class ArcChallengeAuthPolicy(HTTPPolicy):
"""Policy for handling Azure Arc's challenge authentication"""

Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/azure/identity/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
VERSION = "1.17.0"
VERSION = "1.17.0b2"
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/identity/azure-identity",
keywords="azure, azure sdk",
classifiers=[
"Development Status :: 5 - Production/Stable",
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3",
Expand Down
132 changes: 124 additions & 8 deletions sdk/identity/azure-identity/tests/test_managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import os
import sys
import time

try:
Expand Down Expand Up @@ -883,9 +884,10 @@ def test_azure_arc(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
token = ManagedIdentityCredential(transport=transport).get_token(scope)
assert token.token == access_token
assert token.expires_on == expires_on
with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
token = ManagedIdentityCredential(transport=transport).get_token(scope)
assert token.token == access_token
assert token.expires_on == expires_on


def test_azure_arc_tenant_id(tmpdir):
Expand Down Expand Up @@ -936,9 +938,10 @@ def test_azure_arc_tenant_id(tmpdir):
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
assert token.token == access_token
assert token.expires_on == expires_on
with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
assert token.token == access_token
assert token.expires_on == expires_on


def test_azure_arc_client_id():
Expand All @@ -950,10 +953,123 @@ def test_azure_arc_client_id():
EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
},
):
credential = ManagedIdentityCredential(client_id="some-guid")
with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
credential = ManagedIdentityCredential(client_id="some-guid")

with pytest.raises(ClientAuthenticationError):
with pytest.raises(ClientAuthenticationError) as ex:
credential.get_token("scope")
assert "not supported" in str(ex.value)


def test_azure_arc_key_too_large(tmp_path):

api_version = "2019-11-01"
identity_endpoint = "http://localhost:42/token"
imds_endpoint = "http://localhost:42"
scope = "scope"
secret_key = "X" * 4097

key_file = tmp_path / "key_file.key"
key_file.write_text(secret_key)
assert key_file.read_text() == secret_key

transport = validating_transport(
requests=[
Request(
base_url=identity_endpoint,
method="GET",
required_headers={"Metadata": "true"},
required_params={"api-version": api_version, "resource": scope},
),
],
responses=[
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
],
)

with mock.patch(
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
with pytest.raises(ClientAuthenticationError) as ex:
ManagedIdentityCredential(transport=transport).get_token(scope)
assert "file size" in str(ex.value)


def test_azure_arc_key_not_exist(tmp_path):

api_version = "2019-11-01"
identity_endpoint = "http://localhost:42/token"
imds_endpoint = "http://localhost:42"
scope = "scope"

transport = validating_transport(
requests=[
Request(
base_url=identity_endpoint,
method="GET",
required_headers={"Metadata": "true"},
required_params={"api-version": api_version, "resource": scope},
),
],
responses=[
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=/path/to/key_file"}),
],
)

with mock.patch(
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
with pytest.raises(ClientAuthenticationError) as ex:
ManagedIdentityCredential(transport=transport).get_token(scope)
assert "not exist" in str(ex.value)


def test_azure_arc_key_invalid(tmp_path):

api_version = "2019-11-01"
identity_endpoint = "http://localhost:42/token"
imds_endpoint = "http://localhost:42"
scope = "scope"
key_file = tmp_path / "key_file.txt"
key_file.write_text("secret")

transport = validating_transport(
requests=[
Request(
base_url=identity_endpoint,
method="GET",
required_headers={"Metadata": "true"},
required_params={"api-version": api_version, "resource": scope},
),
Request(
base_url=identity_endpoint,
method="GET",
required_headers={"Metadata": "true"},
required_params={"api-version": api_version, "resource": scope},
),
],
responses=[
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
],
)

with mock.patch(
"os.environ",
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
):
with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
with pytest.raises(ClientAuthenticationError) as ex:
ManagedIdentityCredential(transport=transport).get_token(scope)
assert "Unexpected file path" in str(ex.value)

with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
with pytest.raises(ClientAuthenticationError) as ex:
ManagedIdentityCredential(transport=transport).get_token(scope)
assert "extension" in str(ex.value)


def test_token_exchange(tmpdir):
Expand Down
Loading