22Tests for refactored OAuth client authentication implementation.
33"""
44
5+ import base64
6+ import json
57import time
68from unittest import mock
9+ from urllib .parse import unquote
710
811import httpx
912import pytest
1013from inline_snapshot import Is , snapshot
1114from pydantic import AnyHttpUrl , AnyUrl
1215
13- from mcp .client .auth import OAuthClientProvider , PKCEParameters
16+ from mcp .client .auth import OAuthClientProvider , OAuthRegistrationError , PKCEParameters
1417from mcp .client .auth .utils import (
1518 build_oauth_authorization_server_metadata_discovery_urls ,
1619 build_protected_resource_metadata_discovery_urls ,
2124 get_client_metadata_scopes ,
2225 handle_registration_response ,
2326)
24- from mcp .shared .auth import OAuthClientInformationFull , OAuthClientMetadata , OAuthToken , ProtectedResourceMetadata
27+ from mcp .shared .auth import (
28+ OAuthClientInformationFull ,
29+ OAuthClientMetadata ,
30+ OAuthMetadata ,
31+ OAuthToken ,
32+ ProtectedResourceMetadata ,
33+ )
2534
2635
2736class MockTokenStorage :
@@ -592,6 +601,41 @@ async def test_register_client_skip_if_registered(self, oauth_provider: OAuthCli
592601 request = await oauth_provider ._register_client ()
593602 assert request is None
594603
604+ @pytest .mark .anyio
605+ async def test_register_client_none_auth_method_with_server_metadata (self , oauth_provider : OAuthClientProvider ):
606+ """Test that token_endpoint_auth_method=None selects from server's supported methods."""
607+ # Set server metadata with specific supported methods
608+ oauth_provider .context .oauth_metadata = OAuthMetadata (
609+ issuer = AnyHttpUrl ("https://auth.example.com" ),
610+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
611+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
612+ token_endpoint_auth_methods_supported = ["client_secret_post" ],
613+ )
614+ # Ensure client_metadata has None for token_endpoint_auth_method is None
615+
616+ request = await oauth_provider ._register_client ()
617+ assert request is not None
618+
619+ body = json .loads (request .content )
620+ assert body ["token_endpoint_auth_method" ] == "client_secret_post"
621+
622+ @pytest .mark .anyio
623+ async def test_register_client_none_auth_method_no_compatible (self , oauth_provider : OAuthClientProvider ):
624+ """Test that registration raises error when no compatible auth methods."""
625+ # Set server metadata with unsupported methods only
626+ oauth_provider .context .oauth_metadata = OAuthMetadata (
627+ issuer = AnyHttpUrl ("https://auth.example.com" ),
628+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
629+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
630+ token_endpoint_auth_methods_supported = ["private_key_jwt" , "client_secret_jwt" ],
631+ )
632+
633+ with pytest .raises (OAuthRegistrationError ) as exc_info :
634+ await oauth_provider ._register_client ()
635+
636+ assert "No compatible authentication methods" in str (exc_info .value )
637+ assert "private_key_jwt" in str (exc_info .value )
638+
595639 @pytest .mark .anyio
596640 async def test_token_exchange_request_authorization_code (self , oauth_provider : OAuthClientProvider ):
597641 """Test token exchange request building."""
@@ -600,6 +644,7 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O
600644 client_id = "test_client" ,
601645 client_secret = "test_secret" ,
602646 redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
647+ token_endpoint_auth_method = "client_secret_post" ,
603648 )
604649
605650 request = await oauth_provider ._exchange_token_authorization_code ("test_auth_code" , "test_verifier" )
@@ -625,6 +670,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
625670 client_id = "test_client" ,
626671 client_secret = "test_secret" ,
627672 redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
673+ token_endpoint_auth_method = "client_secret_post" ,
628674 )
629675
630676 request = await oauth_provider ._refresh_token ()
@@ -640,6 +686,114 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
640686 assert "client_id=test_client" in content
641687 assert "client_secret=test_secret" in content
642688
689+ @pytest .mark .anyio
690+ async def test_basic_auth_token_exchange (self , oauth_provider : OAuthClientProvider ):
691+ """Test token exchange with client_secret_basic authentication."""
692+ # Set up OAuth metadata to support basic auth
693+ oauth_provider .context .oauth_metadata = OAuthMetadata (
694+ issuer = AnyHttpUrl ("https://auth.example.com" ),
695+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
696+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
697+ token_endpoint_auth_methods_supported = ["client_secret_basic" , "client_secret_post" ],
698+ )
699+
700+ client_id_raw = "test@client" # Include special character to test URL encoding
701+ client_secret_raw = "test:secret" # Include colon to test URL encoding
702+
703+ oauth_provider .context .client_info = OAuthClientInformationFull (
704+ client_id = client_id_raw ,
705+ client_secret = client_secret_raw ,
706+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
707+ token_endpoint_auth_method = "client_secret_basic" ,
708+ )
709+
710+ request = await oauth_provider ._exchange_token ("test_auth_code" , "test_verifier" )
711+
712+ # Should use basic auth (registered method)
713+ assert "Authorization" in request .headers
714+ assert request .headers ["Authorization" ].startswith ("Basic " )
715+
716+ # Decode and verify credentials are properly URL-encoded
717+ encoded_creds = request .headers ["Authorization" ][6 :] # Remove "Basic " prefix
718+ decoded = base64 .b64decode (encoded_creds ).decode ()
719+ client_id , client_secret = decoded .split (":" , 1 )
720+
721+ # Check URL encoding was applied
722+ assert client_id == "test%40client" # @ should be encoded as %40
723+ assert client_secret == "test%3Asecret" # : should be encoded as %3A
724+
725+ # Verify decoded values match original
726+ assert unquote (client_id ) == client_id_raw
727+ assert unquote (client_secret ) == client_secret_raw
728+
729+ # client_secret should NOT be in body for basic auth
730+ content = request .content .decode ()
731+ assert "client_secret=" not in content
732+ assert "client_id=test%40client" in content # client_id still in body
733+
734+ @pytest .mark .anyio
735+ async def test_basic_auth_refresh_token (self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken ):
736+ """Test token refresh with client_secret_basic authentication."""
737+ oauth_provider .context .current_tokens = valid_tokens
738+
739+ # Set up OAuth metadata to only support basic auth
740+ oauth_provider .context .oauth_metadata = OAuthMetadata (
741+ issuer = AnyHttpUrl ("https://auth.example.com" ),
742+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
743+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
744+ token_endpoint_auth_methods_supported = ["client_secret_basic" ],
745+ )
746+
747+ client_id = "test_client"
748+ client_secret = "test_secret"
749+ oauth_provider .context .client_info = OAuthClientInformationFull (
750+ client_id = client_id ,
751+ client_secret = client_secret ,
752+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
753+ token_endpoint_auth_method = "client_secret_basic" ,
754+ )
755+
756+ request = await oauth_provider ._refresh_token ()
757+
758+ assert "Authorization" in request .headers
759+ assert request .headers ["Authorization" ].startswith ("Basic " )
760+
761+ encoded_creds = request .headers ["Authorization" ][6 :]
762+ decoded = base64 .b64decode (encoded_creds ).decode ()
763+ assert decoded == f"{ client_id } :{ client_secret } "
764+
765+ # client_secret should NOT be in body
766+ content = request .content .decode ()
767+ assert "client_secret=" not in content
768+
769+ @pytest .mark .anyio
770+ async def test_none_auth_method (self , oauth_provider : OAuthClientProvider ):
771+ """Test 'none' authentication method (public client)."""
772+ oauth_provider .context .oauth_metadata = OAuthMetadata (
773+ issuer = AnyHttpUrl ("https://auth.example.com" ),
774+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
775+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
776+ token_endpoint_auth_methods_supported = ["none" ],
777+ )
778+
779+ client_id = "public_client"
780+ oauth_provider .context .client_info = OAuthClientInformationFull (
781+ client_id = client_id ,
782+ client_secret = None , # No secret for public client
783+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
784+ token_endpoint_auth_method = "none" ,
785+ )
786+
787+ request = await oauth_provider ._exchange_token ("test_auth_code" , "test_verifier" )
788+
789+ # Should NOT have Authorization header
790+ assert "Authorization" not in request .headers
791+
792+ # Should NOT have client_secret in body
793+ content = request .content .decode ()
794+ assert "client_secret=" not in content
795+ assert "client_id=public_client" in content
796+
643797
644798class TestProtectedResourceMetadata :
645799 """Test protected resource handling."""
0 commit comments