Skip to content

Commit fe0e62e

Browse files
jonsheapcarleton
authored andcommitted
Add client_secret_basic auth support to MCP client
- Implement HTTP Basic auth for OAuth token requests - Automatically sets selects auth method when OAuthClientProvider is configured with OAuthClientMetadata that has token_endpoint_auth_method=None. - Made OAuthClientMetadata.token_endpoint_auth_method optional to support the above auto-configuration. - Removed ` "token_endpoint_auth_method": "client_secret_post"` from the simple-auth-client example as is now auto-configured.
1 parent aa50976 commit fe0e62e

File tree

5 files changed

+228
-15
lines changed

5 files changed

+228
-15
lines changed

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ async def callback_handler() -> tuple[str, str | None]:
177177
"redirect_uris": ["http://localhost:3030/callback"],
178178
"grant_types": ["authorization_code", "refresh_token"],
179179
"response_types": ["code"],
180-
"token_endpoint_auth_method": "client_secret_post",
181180
}
182181

183182
async def _default_redirect_handler(authorization_url: str) -> None:

src/mcp/client/auth/oauth2.py

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from collections.abc import AsyncGenerator, Awaitable, Callable
1414
from dataclasses import dataclass, field
1515
from typing import Any, Protocol
16-
from urllib.parse import urlencode, urljoin, urlparse
16+
from urllib.parse import quote, urlencode, urljoin, urlparse
1717

1818
import anyio
1919
import httpx
@@ -173,6 +173,42 @@ def should_include_resource_param(self, protocol_version: str | None = None) ->
173173
# Version format is YYYY-MM-DD, so string comparison works
174174
return protocol_version >= "2025-06-18"
175175

176+
def prepare_token_auth(
177+
self, data: dict[str, str], headers: dict[str, str] | None = None
178+
) -> tuple[dict[str, str], dict[str, str]]:
179+
"""Prepare authentication for token requests.
180+
181+
Args:
182+
data: The form data to send
183+
headers: Optional headers dict to update
184+
185+
Returns:
186+
Tuple of (updated_data, updated_headers)
187+
"""
188+
if headers is None:
189+
headers = {}
190+
191+
if not self.client_info:
192+
return data, headers
193+
194+
auth_method = self.client_info.token_endpoint_auth_method
195+
196+
if auth_method == "client_secret_basic" and self.client_info.client_secret:
197+
# URL-encode client ID and secret per RFC 6749 Section 2.3.1
198+
encoded_id = quote(self.client_info.client_id, safe="")
199+
encoded_secret = quote(self.client_info.client_secret, safe="")
200+
credentials = f"{encoded_id}:{encoded_secret}"
201+
encoded_credentials = base64.b64encode(credentials.encode()).decode()
202+
headers["Authorization"] = f"Basic {encoded_credentials}"
203+
# Don't include client_secret in body for basic auth
204+
data = {k: v for k, v in data.items() if k != "client_secret"}
205+
elif auth_method == "client_secret_post" and self.client_info.client_secret:
206+
# Include client_secret in request body
207+
data["client_secret"] = self.client_info.client_secret
208+
# For auth_method == "none", don't add any client_secret
209+
210+
return data, headers
211+
176212

177213
class OAuthClientProvider(httpx.Auth):
178214
"""
@@ -247,6 +283,27 @@ async def _register_client(self) -> httpx.Request | None:
247283

248284
registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
249285

286+
# If token_endpoint_auth_method is None, auto-select based on server support
287+
if self.context.client_metadata.token_endpoint_auth_method is None:
288+
preference_order = ["client_secret_basic", "client_secret_post", "none"]
289+
290+
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint_auth_methods_supported:
291+
supported = self.context.oauth_metadata.token_endpoint_auth_methods_supported
292+
for method in preference_order:
293+
if method in supported:
294+
registration_data["token_endpoint_auth_method"] = method
295+
break
296+
else:
297+
# No compatible methods between client and server
298+
raise OAuthRegistrationError(
299+
f"No compatible authentication methods. "
300+
f"Server supports: {supported}, "
301+
f"Client supports: {preference_order}"
302+
)
303+
else:
304+
# No server metadata available, use our default preference
305+
registration_data["token_endpoint_auth_method"] = preference_order[0]
306+
250307
return httpx.Request(
251308
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
252309
)
@@ -343,12 +400,11 @@ async def _exchange_token_authorization_code(
343400
if self.context.should_include_resource_param(self.context.protocol_version):
344401
token_data["resource"] = self.context.get_resource_url() # RFC 8707
345402

346-
if self.context.client_info.client_secret:
347-
token_data["client_secret"] = self.context.client_info.client_secret
403+
# Prepare authentication based on preferred method
404+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
405+
token_data, headers = self.context.prepare_token_auth(token_data, headers)
348406

349-
return httpx.Request(
350-
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
351-
)
407+
return httpx.Request("POST", token_url, data=token_data, headers=headers)
352408

353409
async def _handle_token_response(self, response: httpx.Response) -> None:
354410
"""Handle token exchange response."""
@@ -389,12 +445,11 @@ async def _refresh_token(self) -> httpx.Request:
389445
if self.context.should_include_resource_param(self.context.protocol_version):
390446
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
391447

392-
if self.context.client_info.client_secret: # pragma: no branch
393-
refresh_data["client_secret"] = self.context.client_info.client_secret
448+
# Prepare authentication based on preferred method
449+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
450+
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
394451

395-
return httpx.Request(
396-
"POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
397-
)
452+
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
398453

399454
async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
400455
"""Handle token refresh response. Returns True if successful."""

src/mcp/server/auth/handlers/register.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ async def handle(self, request: Request) -> Response:
4949
)
5050

5151
client_id = str(uuid4())
52+
53+
# If auth method is None, default to client_secret_post
54+
if client_metadata.token_endpoint_auth_method is None:
55+
client_metadata.token_endpoint_auth_method = "client_secret_post"
56+
5257
client_secret = None
5358
if client_metadata.token_endpoint_auth_method != "none": # pragma: no branch
5459
# cryptographically secure random 32-byte hex string

src/mcp/shared/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class OAuthClientMetadata(BaseModel):
4545
# supported auth methods for the token endpoint
4646
token_endpoint_auth_method: Literal[
4747
"none", "client_secret_post", "client_secret_basic", "private_key_jwt"
48-
] = "client_secret_post"
48+
] | None = None
4949
# supported grant_types of this implementation
5050
grant_types: list[
5151
Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"] | str

tests/client/test_auth.py

Lines changed: 156 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
Tests for refactored OAuth client authentication implementation.
33
"""
44

5+
import base64
6+
import json
57
import time
68
from unittest import mock
9+
from urllib.parse import unquote
710

811
import httpx
912
import pytest
1013
from inline_snapshot import Is, snapshot
1114
from pydantic import AnyHttpUrl, AnyUrl
1215

13-
from mcp.client.auth import OAuthClientProvider, PKCEParameters
16+
from mcp.client.auth import OAuthClientProvider, OAuthRegistrationError, PKCEParameters
1417
from mcp.client.auth.utils import (
1518
build_oauth_authorization_server_metadata_discovery_urls,
1619
build_protected_resource_metadata_discovery_urls,
@@ -21,7 +24,13 @@
2124
get_client_metadata_scopes,
2225
handle_registration_response,
2326
)
24-
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata
27+
from mcp.shared.auth import (
28+
OAuthClientInformationFull,
29+
OAuthClientMetadata,
30+
OAuthMetadata,
31+
OAuthToken,
32+
ProtectedResourceMetadata,
33+
)
2534

2635

2736
class MockTokenStorage:
@@ -592,6 +601,41 @@ async def test_register_client_skip_if_registered(self, oauth_provider: OAuthCli
592601
request = await oauth_provider._register_client()
593602
assert request is None
594603

604+
@pytest.mark.anyio
605+
async def test_register_client_none_auth_method_with_server_metadata(self, oauth_provider: OAuthClientProvider):
606+
"""Test that token_endpoint_auth_method=None selects from server's supported methods."""
607+
# Set server metadata with specific supported methods
608+
oauth_provider.context.oauth_metadata = OAuthMetadata(
609+
issuer=AnyHttpUrl("https://auth.example.com"),
610+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
611+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
612+
token_endpoint_auth_methods_supported=["client_secret_post"],
613+
)
614+
# Ensure client_metadata has None for token_endpoint_auth_method is None
615+
616+
request = await oauth_provider._register_client()
617+
assert request is not None
618+
619+
body = json.loads(request.content)
620+
assert body["token_endpoint_auth_method"] == "client_secret_post"
621+
622+
@pytest.mark.anyio
623+
async def test_register_client_none_auth_method_no_compatible(self, oauth_provider: OAuthClientProvider):
624+
"""Test that registration raises error when no compatible auth methods."""
625+
# Set server metadata with unsupported methods only
626+
oauth_provider.context.oauth_metadata = OAuthMetadata(
627+
issuer=AnyHttpUrl("https://auth.example.com"),
628+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
629+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
630+
token_endpoint_auth_methods_supported=["private_key_jwt", "client_secret_jwt"],
631+
)
632+
633+
with pytest.raises(OAuthRegistrationError) as exc_info:
634+
await oauth_provider._register_client()
635+
636+
assert "No compatible authentication methods" in str(exc_info.value)
637+
assert "private_key_jwt" in str(exc_info.value)
638+
595639
@pytest.mark.anyio
596640
async def test_token_exchange_request_authorization_code(self, oauth_provider: OAuthClientProvider):
597641
"""Test token exchange request building."""
@@ -600,6 +644,7 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider: O
600644
client_id="test_client",
601645
client_secret="test_secret",
602646
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
647+
token_endpoint_auth_method="client_secret_post",
603648
)
604649

605650
request = await oauth_provider._exchange_token_authorization_code("test_auth_code", "test_verifier")
@@ -625,6 +670,7 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
625670
client_id="test_client",
626671
client_secret="test_secret",
627672
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
673+
token_endpoint_auth_method="client_secret_post",
628674
)
629675

630676
request = await oauth_provider._refresh_token()
@@ -640,6 +686,114 @@ async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider,
640686
assert "client_id=test_client" in content
641687
assert "client_secret=test_secret" in content
642688

689+
@pytest.mark.anyio
690+
async def test_basic_auth_token_exchange(self, oauth_provider: OAuthClientProvider):
691+
"""Test token exchange with client_secret_basic authentication."""
692+
# Set up OAuth metadata to support basic auth
693+
oauth_provider.context.oauth_metadata = OAuthMetadata(
694+
issuer=AnyHttpUrl("https://auth.example.com"),
695+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
696+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
697+
token_endpoint_auth_methods_supported=["client_secret_basic", "client_secret_post"],
698+
)
699+
700+
client_id_raw = "test@client" # Include special character to test URL encoding
701+
client_secret_raw = "test:secret" # Include colon to test URL encoding
702+
703+
oauth_provider.context.client_info = OAuthClientInformationFull(
704+
client_id=client_id_raw,
705+
client_secret=client_secret_raw,
706+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
707+
token_endpoint_auth_method="client_secret_basic",
708+
)
709+
710+
request = await oauth_provider._exchange_token("test_auth_code", "test_verifier")
711+
712+
# Should use basic auth (registered method)
713+
assert "Authorization" in request.headers
714+
assert request.headers["Authorization"].startswith("Basic ")
715+
716+
# Decode and verify credentials are properly URL-encoded
717+
encoded_creds = request.headers["Authorization"][6:] # Remove "Basic " prefix
718+
decoded = base64.b64decode(encoded_creds).decode()
719+
client_id, client_secret = decoded.split(":", 1)
720+
721+
# Check URL encoding was applied
722+
assert client_id == "test%40client" # @ should be encoded as %40
723+
assert client_secret == "test%3Asecret" # : should be encoded as %3A
724+
725+
# Verify decoded values match original
726+
assert unquote(client_id) == client_id_raw
727+
assert unquote(client_secret) == client_secret_raw
728+
729+
# client_secret should NOT be in body for basic auth
730+
content = request.content.decode()
731+
assert "client_secret=" not in content
732+
assert "client_id=test%40client" in content # client_id still in body
733+
734+
@pytest.mark.anyio
735+
async def test_basic_auth_refresh_token(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
736+
"""Test token refresh with client_secret_basic authentication."""
737+
oauth_provider.context.current_tokens = valid_tokens
738+
739+
# Set up OAuth metadata to only support basic auth
740+
oauth_provider.context.oauth_metadata = OAuthMetadata(
741+
issuer=AnyHttpUrl("https://auth.example.com"),
742+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
743+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
744+
token_endpoint_auth_methods_supported=["client_secret_basic"],
745+
)
746+
747+
client_id = "test_client"
748+
client_secret = "test_secret"
749+
oauth_provider.context.client_info = OAuthClientInformationFull(
750+
client_id=client_id,
751+
client_secret=client_secret,
752+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
753+
token_endpoint_auth_method="client_secret_basic",
754+
)
755+
756+
request = await oauth_provider._refresh_token()
757+
758+
assert "Authorization" in request.headers
759+
assert request.headers["Authorization"].startswith("Basic ")
760+
761+
encoded_creds = request.headers["Authorization"][6:]
762+
decoded = base64.b64decode(encoded_creds).decode()
763+
assert decoded == f"{client_id}:{client_secret}"
764+
765+
# client_secret should NOT be in body
766+
content = request.content.decode()
767+
assert "client_secret=" not in content
768+
769+
@pytest.mark.anyio
770+
async def test_none_auth_method(self, oauth_provider: OAuthClientProvider):
771+
"""Test 'none' authentication method (public client)."""
772+
oauth_provider.context.oauth_metadata = OAuthMetadata(
773+
issuer=AnyHttpUrl("https://auth.example.com"),
774+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
775+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
776+
token_endpoint_auth_methods_supported=["none"],
777+
)
778+
779+
client_id = "public_client"
780+
oauth_provider.context.client_info = OAuthClientInformationFull(
781+
client_id=client_id,
782+
client_secret=None, # No secret for public client
783+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
784+
token_endpoint_auth_method="none",
785+
)
786+
787+
request = await oauth_provider._exchange_token("test_auth_code", "test_verifier")
788+
789+
# Should NOT have Authorization header
790+
assert "Authorization" not in request.headers
791+
792+
# Should NOT have client_secret in body
793+
content = request.content.decode()
794+
assert "client_secret=" not in content
795+
assert "client_id=public_client" in content
796+
643797

644798
class TestProtectedResourceMetadata:
645799
"""Test protected resource handling."""

0 commit comments

Comments
 (0)