Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add environment variable for redirecting IMDS token requests #18967

Merged
merged 4 commits into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class EnvironmentVariables:
AZURE_PASSWORD = "AZURE_PASSWORD"
USERNAME_PASSWORD_VARS = (AZURE_CLIENT_ID, AZURE_USERNAME, AZURE_PASSWORD)

AZURE_POD_IDENTITY_TOKEN_URL = "AZURE_POD_IDENTITY_TOKEN_URL"
IDENTITY_ENDPOINT = "IDENTITY_ENDPOINT"
IDENTITY_HEADER = "IDENTITY_HEADER"
IDENTITY_SERVER_THUMBPRINT = "IDENTITY_SERVER_THUMBPRINT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import logging
import os
from typing import TYPE_CHECKING

import six
Expand All @@ -11,6 +12,7 @@
from azure.core.pipeline.transport import HttpRequest

from .. import CredentialUnavailableError
from .._constants import EnvironmentVariables
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.managed_identity_client import ManagedIdentityClient

Expand All @@ -34,7 +36,7 @@


def get_request(scope, identity_config):
request = HttpRequest("GET", IMDS_URL)
request = HttpRequest("GET", os.environ.get(EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL, IMDS_URL))
request.format_parameters(dict({"api-version": "2018-02-01", "resource": scope}, **identity_config))
return request

Expand All @@ -45,7 +47,10 @@ def __init__(self, **kwargs):
super(ImdsCredential, self).__init__()

self._client = ManagedIdentityClient(get_request, **dict(PIPELINE_SETTINGS, **kwargs))
self._endpoint_available = None # type: Optional[bool]
if EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL in os.environ:
self._endpoint_available = True # type: Optional[bool]
else:
self._endpoint_available = None
self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs

def _acquire_token_silently(self, *scopes):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# Licensed under the MIT License.
# ------------------------------------
import logging
import os
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError, HttpResponseError

from ... import CredentialUnavailableError
from ..._constants import EnvironmentVariables
from .._internal import AsyncContextManager
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.managed_identity_client import AsyncManagedIdentityClient
Expand All @@ -25,7 +27,10 @@ def __init__(self, **kwargs: "Any") -> None:
super().__init__()

self._client = AsyncManagedIdentityClient(get_request, **PIPELINE_SETTINGS, **kwargs)
self._endpoint_available = None # type: Optional[bool]
if EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL in os.environ:
self._endpoint_available = True # type: Optional[bool]
else:
self._endpoint_available = None
self._user_assigned_identity = "client_id" in kwargs or "identity_config" in kwargs

async def close(self) -> None:
Expand Down
38 changes: 38 additions & 0 deletions sdk/identity/azure-identity/tests/test_imds_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from azure.core.exceptions import ClientAuthenticationError

from azure.identity import CredentialUnavailableError
from azure.identity._constants import EnvironmentVariables
from azure.identity._credentials.imds import ImdsCredential, IMDS_URL, PIPELINE_SETTINGS
from azure.identity._internal.user_agent import USER_AGENT
import pytest
Expand Down Expand Up @@ -176,6 +177,43 @@ def test_identity_config():
assert token == expected_token


def test_imds_url_override():
url = "https://localhost/token"
expected_token = "***"
scope = "scope"
now = int(time.time())

transport = validating_transport(
requests=[
Request(
base_url=url,
method="GET",
required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
required_params={"api-version": "2018-02-01", "resource": scope},
),
],
responses=[
mock_response(
json_payload={
"access_token": expected_token,
"expires_in": 42,
"expires_on": now + 42,
"ext_expires_in": 42,
"not_before": now,
"resource": scope,
"token_type": "Bearer",
}
),
],
)

with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL: url}, clear=True):
credential = ImdsCredential(transport=transport)
token = credential.get_token(scope)

assert token.token == expected_token


@pytest.mark.usefixtures("record_imds_test")
class RecordedTests(RecordedTestCase):
def test_system_assigned(self):
Expand Down
40 changes: 39 additions & 1 deletion sdk/identity/azure-identity/tests/test_imds_credential_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import CredentialUnavailableError
from azure.identity._constants import EnvironmentVariables
from azure.identity._credentials.imds import IMDS_URL
from azure.identity._internal.user_agent import USER_AGENT
from azure.identity.aio._credentials.imds import ImdsCredential, PIPELINE_SETTINGS
from azure.identity._credentials.imds import IMDS_URL
import pytest

from helpers import mock_response, Request
Expand Down Expand Up @@ -211,6 +212,43 @@ async def test_identity_config():
assert token == expected_token


async def test_imds_url_override():
url = "https://localhost/token"
expected_token = "***"
scope = "scope"
now = int(time.time())

transport = async_validating_transport(
requests=[
Request(
base_url=url,
method="GET",
required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
required_params={"api-version": "2018-02-01", "resource": scope},
),
],
responses=[
mock_response(
json_payload={
"access_token": expected_token,
"expires_in": 42,
"expires_on": now + 42,
"ext_expires_in": 42,
"not_before": now,
"resource": scope,
"token_type": "Bearer",
}
),
],
)

with mock.patch.dict("os.environ", {EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL: url}, clear=True):
credential = ImdsCredential(transport=transport)
token = await credential.get_token(scope)

assert token.token == expected_token


@pytest.mark.usefixtures("record_imds_test")
class RecordedTests(RecordedTestCase):
@await_test
Expand Down