diff --git a/msal/managed_identity.py b/msal/managed_identity.py index bbd4f52f..608bc1bf 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -40,7 +40,7 @@ class ManagedIdentity(UserDict): _types_mapping = { # Maps type name in configuration to type name on wire CLIENT_ID: "client_id", - RESOURCE_ID: "mi_res_id", + RESOURCE_ID: "msi_res_id", # VM's IMDS prefers msi_res_id https://github.com/Azure/azure-rest-api-specs/blob/dba6ed1f03bda88ac6884c0a883246446cc72495/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable/2018-10-01/imds.json#L233-L239 OBJECT_ID: "object_id", } @@ -430,9 +430,9 @@ def _obtain_token(http_client, managed_identity, resource): return _obtain_token_on_azure_vm(http_client, managed_identity, resource) -def _adjust_param(params, managed_identity): +def _adjust_param(params, managed_identity, types_mapping=None): # Modify the params dict in place - id_name = ManagedIdentity._types_mapping.get( + id_name = (types_mapping or ManagedIdentity._types_mapping).get( managed_identity.get(ManagedIdentity.ID_TYPE)) if id_name: params[id_name] = managed_identity[ManagedIdentity.ID] @@ -479,7 +479,12 @@ def _obtain_token_on_app_service( "api-version": "2019-08-01", "resource": resource, } - _adjust_param(params, managed_identity) + _adjust_param(params, managed_identity, types_mapping={ + ManagedIdentity.CLIENT_ID: "client_id", + ManagedIdentity.RESOURCE_ID: "mi_res_id", # App Service's resource id uses "mi_res_id" + ManagedIdentity.OBJECT_ID: "object_id", + }) + resp = http_client.get( endpoint, params=params, diff --git a/tests/test_mi.py b/tests/test_mi.py index ec2803ca..1f33fe73 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -139,6 +139,22 @@ def test_vm_error_should_be_returned_as_is(self): json.loads(raw_error), self.app.acquire_token_for_client(resource="R")) self.assertEqual({}, self.app._token_cache._cache) + def test_vm_resource_id_parameter_should_be_msi_res_id(self): + app = ManagedIdentityClient( + {"ManagedIdentityIdType": "ResourceId", "Id": "1234"}, + http_client=requests.Session(), + ) + with patch.object(app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_in": 3600, "resource": "R"}', + )) as mocked_method: + app.acquire_token_for_client(resource="R") + mocked_method.assert_called_with( + 'http://169.254.169.254/metadata/identity/oauth2/token', + params={'api-version': '2018-02-01', 'resource': 'R', 'msi_res_id': '1234'}, + headers={'Metadata': 'true'}, + ) + @patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"}) class AppServiceTestCase(ClientTestCase): @@ -164,6 +180,22 @@ def test_app_service_error_should_be_normalized(self): }, self.app.acquire_token_for_client(resource="R")) self.assertEqual({}, self.app._token_cache._cache) + def test_app_service_resource_id_parameter_should_be_mi_res_id(self): + app = ManagedIdentityClient( + {"ManagedIdentityIdType": "ResourceId", "Id": "1234"}, + http_client=requests.Session(), + ) + with patch.object(app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_on": 12345, "resource": "R"}', + )) as mocked_method: + app.acquire_token_for_client(resource="R") + mocked_method.assert_called_with( + 'http://localhost', + params={'api-version': '2019-08-01', 'resource': 'R', 'mi_res_id': '1234'}, + headers={'X-IDENTITY-HEADER': 'foo', 'Metadata': 'true'}, + ) + @patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"}) class MachineLearningTestCase(ClientTestCase):