Skip to content

Commit

Permalink
Error out on invalid ManagedIdentity dict
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Sep 6, 2024
1 parent 85c93f8 commit 0a756e9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
8 changes: 6 additions & 2 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ class ManagedIdentity(UserDict):

@classmethod
def is_managed_identity(cls, unknown):
return isinstance(unknown, ManagedIdentity) or (
isinstance(unknown, dict) and cls.ID_TYPE in unknown)
return (isinstance(unknown, ManagedIdentity)
or cls.is_system_assigned(unknown)
or cls.is_user_assigned(unknown))

@classmethod
def is_system_assigned(cls, unknown):
Expand Down Expand Up @@ -217,6 +218,9 @@ def __init__(
)
token = client.acquire_token_for_client("resource")
"""
if not ManagedIdentity.is_managed_identity(managed_identity):
raise ManagedIdentityError(
f"Incorrect managed_identity: {managed_identity}")
self._managed_identity = managed_identity
self._http_client = _ThrottledHttpClient(
# This class only throttles excess token acquisition requests.
Expand Down
11 changes: 11 additions & 0 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def setUp(self):
http_client=requests.Session(),
)

def test_error_out_on_invalid_input(self):
with self.assertRaises(ManagedIdentityError):
ManagedIdentityClient({"foo": "bar"}, http_client=requests.Session())
with self.assertRaises(ManagedIdentityError):
ManagedIdentityClient(
{"ManagedIdentityIdType": "undefined", "Id": "foo"},
http_client=requests.Session())

def assertCacheStatus(self, app):
cache = app._token_cache._cache
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
Expand Down Expand Up @@ -241,6 +249,9 @@ class ArcTestCase(ClientTestCase):
"WWW-Authenticate": "Basic realm=/tmp/foo",
})

def test_error_out_on_invalid_input(self, mocked_stat):
return super(ArcTestCase, self).test_error_out_on_invalid_input()

def test_happy_path(self, mocked_stat):
expires_in = 1234
with patch.object(self.app._http_client, "get", side_effect=[
Expand Down

0 comments on commit 0a756e9

Please sign in to comment.