Skip to content

Merge release 1.32.3 back to dev branch #816

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ def __init__(
except (
FileNotFoundError, # Or IOError in Python 2
pickle.UnpicklingError, # A corrupted http cache file
AttributeError, # Cache created by a different version of MSAL
):
persisted_http_cache = {} # Recover by starting afresh
atexit.register(lambda: pickle.dump(
Expand Down
14 changes: 9 additions & 5 deletions msal/individual_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def __init__(self, mapping=None, capacity=None, expires_in=None, lock=None,
self._expires_in = expires_in
self._lock = Lock() if lock is None else lock

def _peek(self):
    """Return the stored ``(sequence, timestamps)`` index as-is.

    No maintenance is triggered by this call. When nothing has been stored
    under the reserved index key yet, a fresh empty ``([], {})`` pair is
    returned instead.
    """
    empty_index = ([], {})
    return self._mapping.get(self._INDEX, empty_index)

def _validate_key(self, key):
if key == self._INDEX:
raise ValueError("key {} is a reserved keyword in {}".format(
Expand All @@ -85,7 +89,7 @@ def _set(self, key, value, expires_in):
# This internal implementation powers both set() and __setitem__(),
# so that they don't depend on each other.
self._validate_key(key)
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
sequence, timestamps = self._peek()
self._maintenance(sequence, timestamps) # O(logN)
now = int(time.time())
expires_at = now + expires_in
Expand Down Expand Up @@ -136,7 +140,7 @@ def __getitem__(self, key): # O(1)
self._validate_key(key)
with self._lock:
# Skip self._maintenance(), because it would need O(logN) time
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
sequence, timestamps = self._peek()
expires_at, created_at = timestamps[key] # Would raise KeyError accordingly
now = int(time.time())
if not created_at <= now < expires_at:
Expand All @@ -155,22 +159,22 @@ def __delitem__(self, key): # O(1)
with self._lock:
# Skip self._maintenance(), because it would need O(logN) time
self._mapping.pop(key, None) # O(1)
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
sequence, timestamps = self._peek()
del timestamps[key] # O(1)
self._mapping[self._INDEX] = sequence, timestamps

def __len__(self): # O(logN)
"""Drop all expired items and return the remaining length"""
with self._lock:
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
sequence, timestamps = self._peek()
self._maintenance(sequence, timestamps) # O(logN)
self._mapping[self._INDEX] = sequence, timestamps
return len(timestamps) # Faster than iter(self._mapping) when it is on disk

def __iter__(self):
"""Drop all expired items and return an iterator of the remaining items"""
with self._lock:
sequence, timestamps = self._mapping.get(self._INDEX, ([], {}))
sequence, timestamps = self._peek()
self._maintenance(sequence, timestamps) # O(logN)
self._mapping[self._INDEX] = sequence, timestamps
return iter(timestamps) # Faster than iter(self._mapping) when it is on disk
Expand Down
9 changes: 4 additions & 5 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def __init__(self, *, client_id=None, resource_id=None, object_id=None):


class _ThrottledHttpClient(ThrottledHttpClientBase):
def __init__(self, http_client, **kwargs):
super(_ThrottledHttpClient, self).__init__(http_client, **kwargs)
def __init__(self, *args, **kwargs):
super(_ThrottledHttpClient, self).__init__(*args, **kwargs)
self.get = IndividualCache( # All MIs (except Cloud Shell) use GETs
mapping=self._expiring_mapping,
key_maker=lambda func, args, kwargs: "REQ {} hash={} 429/5xx/Retry-After".format(
Expand All @@ -124,7 +124,7 @@ def __init__(self, http_client, **kwargs):
str(kwargs.get("params")) + str(kwargs.get("data"))),
),
expires_in=RetryAfterParser(5).parse, # 5 seconds default for non-PCA
)(http_client.get)
)(self.get) # Note: Decorate the parent get(), not the http_client.get()


class ManagedIdentityClient(object):
Expand Down Expand Up @@ -233,8 +233,7 @@ def __init__(
# (especially for 410 which was supposed to be a permanent failure).
# 2. MI on Service Fabric specifically suggests to not retry on 404.
# ( https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-cluster-managed-identity-service-fabric-app-code#error-handling )
http_client.http_client # Patch the raw (unpatched) http client
if isinstance(http_client, ThrottledHttpClientBase) else http_client,
http_client,
http_cache=http_cache,
)
self._token_cache = token_cache or TokenCache()
Expand Down
2 changes: 1 addition & 1 deletion msal/sku.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"""

# The __init__.py will import this. Not the other way around.
__version__ = "1.32.0"
__version__ = "1.32.3"
SKU = "MSAL.Python"
88 changes: 58 additions & 30 deletions msal/throttled_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,34 @@

from .individual_cache import _IndividualCache as IndividualCache
from .individual_cache import _ExpiringMapping as ExpiringMapping
from .oauth2cli.http import Response
from .exceptions import MsalServiceError


# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"


def _get_headers(response):
    """Return the headers of ``response``, degrading gracefully to ``{}``.

    MSAL's HttpResponse did not have headers until 1.23.0
    https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
    so the attribute may be absent altogether. A ``headers`` attribute that is
    present but set to ``None`` is treated the same way, so that callers can
    always iterate the result or call ``.items()`` on it without exception.

    :param response: Any response-like object.
    :return: The response's own headers mapping, or an empty dict.
    """
    headers = getattr(response, "headers", None)
    # Compare against None (rather than truthiness) so that an empty,
    # possibly case-insensitive, mapping is still handed back unchanged.
    return {} if headers is None else headers


class RetryAfterParser(object):
FIELD_NAME_LOWER = "Retry-After".lower()
def __init__(self, default_value=None):
self._default_value = 5 if default_value is None else default_value

def parse(self, *, result, **ignored):
"""Return seconds to throttle"""
response = result
lowercase_headers = {k.lower(): v for k, v in getattr(
# Historically, MSAL's HttpResponse does not always have headers
response, "headers", {}).items()}
lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
if not (response.status_code == 429 or response.status_code >= 500
or "retry-after" in lowercase_headers):
or self.FIELD_NAME_LOWER in lowercase_headers):
return 0 # Quick exit
retry_after = lowercase_headers.get("retry-after", self._default_value)
retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
try:
# AAD's retry_after uses integer format only
# https://stackoverflow.microsoft.com/questions/264931/264932
Expand All @@ -37,27 +45,55 @@ def _extract_data(kwargs, key, default=None):
return data.get(key) if isinstance(data, dict) else default


class NormalizedResponse(Response):
    """An http response with the shape defined in Response,
    but containing only the data we will store in cache.

    Only the status code, the text and the (lower-cased) headers are copied
    from the raw response; no reference to the raw response itself is kept.
    """
    def __init__(self, raw_response):
        super().__init__()
        self.status_code = raw_response.status_code
        self.text = raw_response.text
        # Attempted storing only a small set of headers (such as Retry-After),
        # but it tends to lead to missing information (such as WWW-Authenticate).
        # So we store all headers, which are expected to contain only public info,
        # because we throttle only error responses and public responses.
        self.headers = {
            k.lower(): v for k, v in _get_headers(raw_response).items()}

    ## Note: Don't use the following line,
    ## because when being pickled, it will indirectly pickle the whole raw_response
    # self.raise_for_status = raw_response.raise_for_status
    def raise_for_status(self):
        """Raise MsalServiceError when the status code is 400 or above."""
        if self.status_code >= 400:
            message = "HTTP Error: {}".format(self.status_code)
            # NOTE(review): MsalServiceError's message template is believed to
            # require the "error" and "error_description" keywords; without
            # them, constructing the exception itself may fail with KeyError
            # instead of raising the intended error.
            # Confirm against msal/exceptions.py.
            raise MsalServiceError(
                message, error="http_error", error_description=message)


class ThrottledHttpClientBase(object):
"""Throttle the given http_client by storing and retrieving data from cache.

This wrapper exists so that our patching post() and get() would prevent
re-patching side effect when/if same http_client being reused.
This base exists so that:
1. These base post() and get() will return a NormalizedResponse
2. The base __init__() will NOT re-throttle even if caller accidentally nested ThrottledHttpClient.

The subclass should implement post() and/or get()
Subclasses shall only need to dynamically decorate their post() and get() methods
in their __init__() method.
"""
def __init__(self, http_client, *, http_cache=None):
self.http_client = http_client
self.http_client = http_client.http_client if isinstance(
# If it is already a ThrottledHttpClientBase, we use its raw (unthrottled) http client
http_client, ThrottledHttpClientBase) else http_client
self._expiring_mapping = ExpiringMapping( # It will automatically clean up
mapping=http_cache if http_cache is not None else {},
capacity=1024, # To prevent cache blowing up especially for CCA
lock=Lock(), # TODO: This should ideally also allow customization
)

def post(self, *args, **kwargs):
return self.http_client.post(*args, **kwargs)
return NormalizedResponse(self.http_client.post(*args, **kwargs))

def get(self, *args, **kwargs):
return self.http_client.get(*args, **kwargs)
return NormalizedResponse(self.http_client.get(*args, **kwargs))

def close(self):
return self.http_client.close()
Expand All @@ -68,12 +104,11 @@ def _hash(raw):


class ThrottledHttpClient(ThrottledHttpClientBase):
def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
super(ThrottledHttpClient, self).__init__(http_client, **kwargs)

_post = http_client.post # We'll patch _post, and keep original post() intact

_post = IndividualCache(
"""A throttled http client that is used by MSAL's non-managed identity clients."""
def __init__(self, *args, default_throttle_time=None, **kwargs):
"""Decorate self.post() and self.get() dynamically"""
super(ThrottledHttpClient, self).__init__(*args, **kwargs)
self.post = IndividualCache(
# Internal specs requires throttling on at least token endpoint,
# here we have a generic patch for POST on all endpoints.
mapping=self._expiring_mapping,
Expand All @@ -91,9 +126,9 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
_extract_data(kwargs, "username")))), # "account" of ROPC
),
expires_in=RetryAfterParser(default_throttle_time or 5).parse,
)(_post)
)(self.post)

_post = IndividualCache( # It covers the "UI required cache"
self.post = IndividualCache( # It covers the "UI required cache"
mapping=self._expiring_mapping,
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
args[0], # It is the url, typically containing authority and tenant
Expand Down Expand Up @@ -125,12 +160,10 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
isinstance(kwargs.get("data"), dict)
and kwargs["data"].get("grant_type") == DEVICE_AUTH_GRANT
)
and "retry-after" not in set( # Leave it to the Retry-After decorator
h.lower() for h in getattr(result, "headers", {}).keys())
and RetryAfterParser.FIELD_NAME_LOWER not in set( # Otherwise leave it to the Retry-After decorator
h.lower() for h in _get_headers(result))
else 0,
)(_post)

self.post = _post
)(self.post)

self.get = IndividualCache( # Typically those discovery GETs
mapping=self._expiring_mapping,
Expand All @@ -140,9 +173,4 @@ def __init__(self, http_client, *, default_throttle_time=None, **kwargs):
),
expires_in=lambda result=None, **ignored:
3600*24 if 200 <= result.status_code < 300 else 0,
)(http_client.get)

# The following 2 methods have been defined dynamically by __init__()
#def post(self, *args, **kwargs): pass
#def get(self, *args, **kwargs): pass

)(self.get)
27 changes: 22 additions & 5 deletions tests/test_individual_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
class TestExpiringMapping(unittest.TestCase):
def setUp(self):
self.mapping = {}
self.m = ExpiringMapping(mapping=self.mapping, capacity=2, expires_in=1)
self.expires_in = 1
self.m = ExpiringMapping(
mapping=self.mapping, capacity=2, expires_in=self.expires_in)

def how_many(self):
    # Count the items currently tracked by the mapping under test,
    # WITHOUT triggering a purge (it only peeks at the index)
    _sequence, timestamps = self.m._peek()
    return len(timestamps)

def test_should_disallow_accessing_reserved_keyword(self):
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -40,11 +46,21 @@ def test_iter_should_purge(self):
sleep(1)
self.assertEqual([], list(self.m))

def test_get_should_purge(self):
def test_get_should_not_purge_and_should_return_only_when_the_item_is_still_valid(self):
    # Reading an expired key must raise KeyError, and that read must not
    # trigger a purge of the other expired items.
    self.m["thing one"] = "one"
    self.m["thing two"] = "two"
    sleep(1)  # Outlive the 1-second expires_in configured in setUp()
    self.assertEqual(2, self.how_many(), "We begin with 2 items")
    with self.assertRaises(KeyError):
        self.m["thing one"]
    # NOTE(review): the count dropping from 2 to 1 presumably means the failed
    # lookup discards only the key that was read - confirm in __getitem__().
    self.assertEqual(1, self.how_many(), "get() should not purge the remaining items")

def test_setitem_should_purge(self):
    # Storing a new item must drop every item that has already expired.
    self.m["thing one"] = "one"
    sleep(1)  # Outlive the 1-second expires_in configured in setUp()
    self.m["thing two"] = "two"
    self.assertEqual(1, self.how_many(), "setitem() should purge all expired items")
    self.assertEqual("two", self.m["thing two"], "The remaining item should be thing two")

def test_various_expiring_time(self):
self.assertEqual(0, len(self.m))
Expand All @@ -57,12 +73,13 @@ def test_various_expiring_time(self):
def test_old_item_can_be_updated_with_new_expiry_time(self):
self.assertEqual(0, len(self.m))
self.m["thing"] = "one"
self.m.set("thing", "two", 2)
new_lifetime = 3 # 2-second seems too short and causes flakiness
self.m.set("thing", "two", new_lifetime)
self.assertEqual(1, len(self.m), "It contains 1 item")
self.assertEqual("two", self.m["thing"], 'Already been updated to "two"')
sleep(1)
sleep(self.expires_in)
self.assertEqual("two", self.m["thing"], "Not yet expires")
sleep(1)
sleep(new_lifetime - self.expires_in)
self.assertEqual(0, len(self.m))

def test_oversized_input_should_purge_most_aging_item(self):
Expand Down
35 changes: 34 additions & 1 deletion tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
from mock import patch, ANY, mock_open, Mock
import requests

from tests.http_client import MinimalResponse
from tests.test_throttled_http_client import (
MinimalResponse, ThrottledHttpClientBaseTestCase, DummyHttpClient)
from msal import (
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
ManagedIdentityClient,
ManagedIdentityError,
ArcPlatformNotSupportedError,
)
from msal.managed_identity import (
_ThrottledHttpClient,
_supported_arc_platforms_and_their_prefixes,
get_managed_identity_source,
APP_SERVICE,
Expand Down Expand Up @@ -49,6 +51,37 @@ def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_f
{"ManagedIdentityIdType": "SystemAssigned", "Id": None})


class ThrottledHttpClientTestCase(ThrottledHttpClientBaseTestCase):
    # Exercises the throttled http client used by ManagedIdentityClient

    def test_throttled_http_client_should_not_alter_original_http_client(self):
        self.assertNotAlteringOriginalHttpClient(_ThrottledHttpClient)

    def test_throttled_http_client_should_not_cache_successful_http_response(self):
        # A 200 response carrying a token must leave the http cache untouched
        cache = {}
        dummy_client = DummyHttpClient(
            status_code=200,
            response_text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
        )
        client = ManagedIdentityClient(
            SystemAssignedManagedIdentity(),
            http_client=dummy_client,
            http_cache=cache,
        )
        token_response = client.acquire_token_for_client(resource="R")
        self.assertEqual("AT", token_response["access_token"])
        self.assertEqual({}, cache, "Should not cache successful http response")

    def test_throttled_http_client_should_cache_unsuccessful_http_response(self):
        # A 400 response with Retry-After must be stored in the http cache,
        # and what gets stored must pass the base class's clean-pickle check
        cache = {}
        dummy_client = DummyHttpClient(
            status_code=400,
            response_headers={"Retry-After": "1"},
            response_text='{"error": "invalid_request"}',
        )
        client = ManagedIdentityClient(
            SystemAssignedManagedIdentity(),
            http_client=dummy_client,
            http_cache=cache,
        )
        error_response = client.acquire_token_for_client(resource="R")
        self.assertEqual("invalid_request", error_response["error"])
        self.assertNotEqual({}, cache, "Should cache unsuccessful http response")
        self.assertCleanPickle(cache)


class ClientTestCase(unittest.TestCase):
maxDiff = None

Expand Down
Loading