33# Licensed under the MIT License.
44# ------------------------------------
55import os
6+ import sys
67import time
78
89try :
@@ -883,9 +884,10 @@ def test_azure_arc(tmpdir):
883884 "os.environ" ,
884885 {EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
885886 ):
886- token = ManagedIdentityCredential (transport = transport ).get_token (scope )
887- assert token .token == access_token
888- assert token .expires_on == expires_on
887+ with mock .patch ("azure.identity._credentials.azure_arc._validate_key_file" , lambda x : None ):
888+ token = ManagedIdentityCredential (transport = transport ).get_token (scope )
889+ assert token .token == access_token
890+ assert token .expires_on == expires_on
889891
890892
891893def test_azure_arc_tenant_id (tmpdir ):
@@ -936,9 +938,10 @@ def test_azure_arc_tenant_id(tmpdir):
936938 "os.environ" ,
937939 {EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
938940 ):
939- token = ManagedIdentityCredential (transport = transport ).get_token (scope , tenant_id = "tenant_id" )
940- assert token .token == access_token
941- assert token .expires_on == expires_on
941+ with mock .patch ("azure.identity._credentials.azure_arc._validate_key_file" , lambda x : None ):
942+ token = ManagedIdentityCredential (transport = transport ).get_token (scope , tenant_id = "tenant_id" )
943+ assert token .token == access_token
944+ assert token .expires_on == expires_on
942945
943946
944947def test_azure_arc_client_id ():
@@ -950,10 +953,123 @@ def test_azure_arc_client_id():
950953 EnvironmentVariables .IMDS_ENDPOINT : "http://localhost:42" ,
951954 },
952955 ):
953- credential = ManagedIdentityCredential (client_id = "some-guid" )
956+ with mock .patch ("azure.identity._credentials.azure_arc._validate_key_file" , lambda x : None ):
957+ credential = ManagedIdentityCredential (client_id = "some-guid" )
954958
955- with pytest .raises (ClientAuthenticationError ):
959+ with pytest .raises (ClientAuthenticationError ) as ex :
956960 credential .get_token ("scope" )
961+ assert "not supported" in str (ex .value )
962+
963+
964+ def test_azure_arc_key_too_large (tmp_path ):
965+
966+ api_version = "2019-11-01"
967+ identity_endpoint = "http://localhost:42/token"
968+ imds_endpoint = "http://localhost:42"
969+ scope = "scope"
970+ secret_key = "X" * 4097
971+
972+ key_file = tmp_path / "key_file.key"
973+ key_file .write_text (secret_key )
974+ assert key_file .read_text () == secret_key
975+
976+ transport = validating_transport (
977+ requests = [
978+ Request (
979+ base_url = identity_endpoint ,
980+ method = "GET" ,
981+ required_headers = {"Metadata" : "true" },
982+ required_params = {"api-version" : api_version , "resource" : scope },
983+ ),
984+ ],
985+ responses = [
986+ mock_response (status_code = 401 , headers = {"WWW-Authenticate" : "Basic realm={}" .format (key_file )}),
987+ ],
988+ )
989+
990+ with mock .patch (
991+ "os.environ" ,
992+ {EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
993+ ):
994+ with mock .patch ("azure.identity._credentials.azure_arc._get_key_file_path" , lambda : str (tmp_path )):
995+ with pytest .raises (ClientAuthenticationError ) as ex :
996+ ManagedIdentityCredential (transport = transport ).get_token (scope )
997+ assert "file size" in str (ex .value )
998+
999+
1000+ def test_azure_arc_key_not_exist (tmp_path ):
1001+
1002+ api_version = "2019-11-01"
1003+ identity_endpoint = "http://localhost:42/token"
1004+ imds_endpoint = "http://localhost:42"
1005+ scope = "scope"
1006+
1007+ transport = validating_transport (
1008+ requests = [
1009+ Request (
1010+ base_url = identity_endpoint ,
1011+ method = "GET" ,
1012+ required_headers = {"Metadata" : "true" },
1013+ required_params = {"api-version" : api_version , "resource" : scope },
1014+ ),
1015+ ],
1016+ responses = [
1017+ mock_response (status_code = 401 , headers = {"WWW-Authenticate" : "Basic realm=/path/to/key_file" }),
1018+ ],
1019+ )
1020+
1021+ with mock .patch (
1022+ "os.environ" ,
1023+ {EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
1024+ ):
1025+ with pytest .raises (ClientAuthenticationError ) as ex :
1026+ ManagedIdentityCredential (transport = transport ).get_token (scope )
1027+ assert "not exist" in str (ex .value )
1028+
1029+
1030+ def test_azure_arc_key_invalid (tmp_path ):
1031+
1032+ api_version = "2019-11-01"
1033+ identity_endpoint = "http://localhost:42/token"
1034+ imds_endpoint = "http://localhost:42"
1035+ scope = "scope"
1036+ key_file = tmp_path / "key_file.txt"
1037+ key_file .write_text ("secret" )
1038+
1039+ transport = validating_transport (
1040+ requests = [
1041+ Request (
1042+ base_url = identity_endpoint ,
1043+ method = "GET" ,
1044+ required_headers = {"Metadata" : "true" },
1045+ required_params = {"api-version" : api_version , "resource" : scope },
1046+ ),
1047+ Request (
1048+ base_url = identity_endpoint ,
1049+ method = "GET" ,
1050+ required_headers = {"Metadata" : "true" },
1051+ required_params = {"api-version" : api_version , "resource" : scope },
1052+ ),
1053+ ],
1054+ responses = [
1055+ mock_response (status_code = 401 , headers = {"WWW-Authenticate" : "Basic realm={}" .format (key_file )}),
1056+ mock_response (status_code = 401 , headers = {"WWW-Authenticate" : "Basic realm={}" .format (key_file )}),
1057+ ],
1058+ )
1059+
1060+ with mock .patch (
1061+ "os.environ" ,
1062+ {EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
1063+ ):
1064+ with mock .patch ("azure.identity._credentials.azure_arc._get_key_file_path" , lambda : "/foo" ):
1065+ with pytest .raises (ClientAuthenticationError ) as ex :
1066+ ManagedIdentityCredential (transport = transport ).get_token (scope )
1067+ assert "Unexpected file path" in str (ex .value )
1068+
1069+ with mock .patch ("azure.identity._credentials.azure_arc._get_key_file_path" , lambda : str (tmp_path )):
1070+ with pytest .raises (ClientAuthenticationError ) as ex :
1071+ ManagedIdentityCredential (transport = transport ).get_token (scope )
1072+ assert "extension" in str (ex .value )
9571073
9581074
9591075def test_token_exchange (tmpdir ):
0 commit comments