Skip to content

Commit

Permalink
Add environment variable for redirecting IMDS token requests (#18967)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jun 3, 2021
1 parent 77e30b8 commit e918edd
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 4 deletions.
1 change: 1 addition & 0 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class EnvironmentVariables:
AZURE_PASSWORD = "AZURE_PASSWORD"
USERNAME_PASSWORD_VARS = (AZURE_CLIENT_ID, AZURE_USERNAME, AZURE_PASSWORD)

AZURE_POD_IDENTITY_TOKEN_URL = "AZURE_POD_IDENTITY_TOKEN_URL"
IDENTITY_ENDPOINT = "IDENTITY_ENDPOINT"
IDENTITY_HEADER = "IDENTITY_HEADER"
IDENTITY_SERVER_THUMBPRINT = "IDENTITY_SERVER_THUMBPRINT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import logging
import os
from typing import TYPE_CHECKING

import six
Expand All @@ -11,6 +12,7 @@
from azure.core.pipeline.transport import HttpRequest

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.managed_identity_client import ManagedIdentityClient

Expand All @@ -34,7 +36,7 @@


def get_request(scope, identity_config):
request = HttpRequest("GET", IMDS_URL)
request = HttpRequest("GET", os.environ.get(EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL, IMDS_URL))
request.format_parameters(dict({"api-version": "2018-02-01", "resource": scope}, **identity_config))
return request

Expand All @@ -45,7 +47,10 @@ def __init__(self, **kwargs):
super(ImdsCredential, self).__init__()

self._client = ManagedIdentityClient(get_request, **dict(PIPELINE_SETTINGS, **kwargs))
self._endpoint_available = None # type: Optional[bool]
if EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL in os.environ:
self._endpoint_available = True # type: Optional[bool]
else:
self._endpoint_available = None
self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs

def _acquire_token_silently(self, *scopes):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# Licensed under the MIT License.
# ------------------------------------
import logging
import os
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError, HttpResponseError

from ... import CredentialUnavailableError
from ..._constants import EnvironmentVariables
from .._internal import AsyncContextManager
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.managed_identity_client import AsyncManagedIdentityClient
Expand All @@ -25,7 +27,10 @@ def __init__(self, **kwargs: "Any") -> None:
super().__init__()

self._client = AsyncManagedIdentityClient(get_request, **PIPELINE_SETTINGS, **kwargs)
self._endpoint_available = None # type: Optional[bool]
if EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL in os.environ:
self._endpoint_available = True # type: Optional[bool]
else:
self._endpoint_available = None
self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs

async def close(self) -> None:
Expand Down
38 changes: 38 additions & 0 deletions sdk/identity/azure-identity/tests/test_imds_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from azure.core.exceptions import ClientAuthenticationError

from azure.identity import CredentialUnavailableError
from azure.identity._constants import EnvironmentVariables
from azure.identity._credentials.imds import ImdsCredential, IMDS_URL, PIPELINE_SETTINGS
from azure.identity._internal.user_agent import USER_AGENT
import pytest
Expand Down Expand Up @@ -176,6 +177,43 @@ def test_identity_config():
assert token == expected_token


def test_imds_url_override():
url = "https://localhost/token"
expected_token = "***"
scope = "scope"
now = int(time.time())

transport = validating_transport(
requests=[
Request(
base_url=url,
method="GET",
required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
required_params={"api-version": "2018-02-01", "resource": scope},
),
],
responses=[
mock_response(
json_payload={
"access_token": expected_token,
"expires_in": 42,
"expires_on": now + 42,
"ext_expires_in": 42,
"not_before": now,
"resource": scope,
"token_type": "Bearer",
}
),
],
)

with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL: url}, clear=True):
credential = ImdsCredential(transport=transport)
token = credential.get_token(scope)

assert token.token == expected_token


@pytest.mark.usefixtures("record_imds_test")
class RecordedTests(RecordedTestCase):
def test_system_assigned(self):
Expand Down
40 changes: 39 additions & 1 deletion sdk/identity/azure-identity/tests/test_imds_credential_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import CredentialUnavailableError
from azure.identity._constants import EnvironmentVariables
from azure.identity._credentials.imds import IMDS_URL
from azure.identity._internal.user_agent import USER_AGENT
from azure.identity.aio._credentials.imds import ImdsCredential, PIPELINE_SETTINGS
from azure.identity._credentials.imds import IMDS_URL
import pytest

from helpers import mock_response, Request
Expand Down Expand Up @@ -211,6 +212,43 @@ async def test_identity_config():
assert token == expected_token


async def test_imds_url_override():
url = "https://localhost/token"
expected_token = "***"
scope = "scope"
now = int(time.time())

transport = async_validating_transport(
requests=[
Request(
base_url=url,
method="GET",
required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
required_params={"api-version": "2018-02-01", "resource": scope},
),
],
responses=[
mock_response(
json_payload={
"access_token": expected_token,
"expires_in": 42,
"expires_on": now + 42,
"ext_expires_in": 42,
"not_before": now,
"resource": scope,
"token_type": "Bearer",
}
),
],
)

with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL: url}, clear=True):
credential = ImdsCredential(transport=transport)
token = await credential.get_token(scope)

assert token.token == expected_token


@pytest.mark.usefixtures("record_imds_test")
class RecordedTests(RecordedTestCase):
@await_test
Expand Down

0 comments on commit e918edd

Please sign in to comment.