-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Implement RFC 7523 JWT flows #1247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
873bbe0
7c65d76
d927bc0
20b5dfc
1a5f104
fe548e5
7e0dfaf
4ab922c
e3a2b6d
ed2a486
a8067e1
0be70c4
13b3478
aaf2cc7
efecc7d
f1d1591
06177d1
2a2f562
31eeb63
ac75345
c3c6725
4a4c007
6677894
fc8331c
39758e2
36532fd
ed23997
e10f7c9
bcc5b39
b90a6c2
ba3cd1e
27f38e2
866172c
cd8e9dc
50a38a6
90428fe
ca9230a
f2eec66
69eaff1
aa5d820
032ff50
9bae21d
1d73750
a0bb22a
5b1d45b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| """ | ||
| OAuth2 Authentication implementation for HTTPX. | ||
|
|
||
| Implements authorization code flow with PKCE and automatic token refresh. | ||
| """ | ||
|
|
||
| from mcp.client.auth.oauth2 import * # noqa: F403 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| import time | ||
| from collections.abc import Awaitable, Callable | ||
| from typing import Any | ||
| from uuid import uuid4 | ||
|
|
||
| import httpx | ||
| import jwt | ||
| from pydantic import BaseModel, Field | ||
|
|
||
| from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage | ||
| from mcp.shared.auth import OAuthClientMetadata | ||
|
|
||
|
|
||
| class JWTParameters(BaseModel): | ||
| """JWT parameters.""" | ||
|
|
||
| assertion: str | None = Field( | ||
| default=None, | ||
| description="JWT assertion for JWT authentication. " | ||
| "Will be used instead of generating a new assertion if provided.", | ||
| ) | ||
|
|
||
| issuer: str | None = Field(default=None, description="Issuer for JWT assertions.") | ||
| subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.") | ||
| audience: str | None = Field(default=None, description="Audience for JWT assertions.") | ||
| claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.") | ||
| jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.") | ||
| jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.") | ||
| jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.") | ||
|
|
||
| def to_assertion(self, with_audience_fallback: str | None = None) -> str: | ||
| if self.assertion is not None: | ||
| # Prebuilt JWT (e.g. acquired out-of-band) | ||
| assertion = self.assertion | ||
| else: | ||
| if not self.jwt_signing_key: | ||
| raise OAuthFlowError("Missing signing key for JWT bearer grant") | ||
| if not self.issuer: | ||
| raise OAuthFlowError("Missing issuer for JWT bearer grant") | ||
| if not self.subject: | ||
| raise OAuthFlowError("Missing subject for JWT bearer grant") | ||
|
|
||
| audience = self.audience if self.audience else with_audience_fallback | ||
| if not audience: | ||
| raise OAuthFlowError("Missing audience for JWT bearer grant") | ||
|
|
||
| now = int(time.time()) | ||
| claims: dict[str, Any] = { | ||
| "iss": self.issuer, | ||
| "sub": self.subject, | ||
| "aud": audience, | ||
| "exp": now + self.jwt_lifetime_seconds, | ||
| "iat": now, | ||
| "jti": str(uuid4()), | ||
| } | ||
| claims.update(self.claims or {}) | ||
|
|
||
| assertion = jwt.encode( | ||
| claims, | ||
| self.jwt_signing_key, | ||
| algorithm=self.jwt_signing_algorithm or "RS256", | ||
| ) | ||
| return assertion | ||
|
|
||
|
|
||
| class RFC7523OAuthClientProvider(OAuthClientProvider): | ||
| """OAuth client provider for RFC7532 clients.""" | ||
|
|
||
| jwt_parameters: JWTParameters | None = None | ||
|
|
||
| def __init__( | ||
| self, | ||
| server_url: str, | ||
| client_metadata: OAuthClientMetadata, | ||
| storage: TokenStorage, | ||
| redirect_handler: Callable[[str], Awaitable[None]] | None = None, | ||
| callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, | ||
| timeout: float = 300.0, | ||
| jwt_parameters: JWTParameters | None = None, | ||
| ) -> None: | ||
| super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout) | ||
| self.jwt_parameters = jwt_parameters | ||
|
|
||
| async def _exchange_token_authorization_code( | ||
| self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None | ||
| ) -> httpx.Request: | ||
| """Build token exchange request for authorization_code flow.""" | ||
| token_data = token_data or {} | ||
| if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt": | ||
| self._add_client_authentication_jwt(token_data=token_data) | ||
| return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data) | ||
|
|
||
| async def _perform_authorization(self) -> httpx.Request: | ||
| """Perform the authorization flow.""" | ||
| if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types: | ||
| token_request = await self._exchange_token_jwt_bearer() | ||
| return token_request | ||
| else: | ||
| return await super()._perform_authorization() | ||
|
|
||
| def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): | ||
| """Add JWT assertion for client authentication to token endpoint parameters.""" | ||
| if not self.jwt_parameters: | ||
| raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow") | ||
| if not self.context.oauth_metadata: | ||
| raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow") | ||
|
|
||
| # We need to set the audience to the issuer identifier of the authorization server | ||
| # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523 | ||
| issuer = str(self.context.oauth_metadata.issuer) | ||
| assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer) | ||
|
|
||
| # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2 | ||
| token_data["client_assertion"] = assertion | ||
| token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" | ||
| # We need to set the audience to the resource server, the audience is difference from the one in claims | ||
| # it represents the resource server that will validate the token | ||
| token_data["audience"] = self.context.get_resource_url() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe you'll want to use the issuer URL here: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the comment here didn't match the required change; I've updated it in the latest commit - I believe the issuer = str(self.context.oauth_metadata.issuer)
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)This line is actually setting the token_data["audience"] = self.context.get_resource_url()as described in RFC 8693. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah yep i had the wrong line. thanks! |
||
|
|
||
| async def _exchange_token_jwt_bearer(self) -> httpx.Request: | ||
| """Build token exchange request for JWT bearer grant.""" | ||
| if not self.context.client_info: | ||
| raise OAuthFlowError("Missing client info") | ||
| if not self.jwt_parameters: | ||
| raise OAuthFlowError("Missing JWT parameters") | ||
| if not self.context.oauth_metadata: | ||
| raise OAuthTokenError("Missing OAuth metadata") | ||
|
|
||
| # We need to set the audience to the issuer identifier of the authorization server | ||
| # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523 | ||
| issuer = str(self.context.oauth_metadata.issuer) | ||
| assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer) | ||
|
|
||
| token_data = { | ||
| "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", | ||
| "assertion": assertion, | ||
| } | ||
|
|
||
| if self.context.should_include_resource_param(self.context.protocol_version): | ||
| token_data["resource"] = self.context.get_resource_url() | ||
|
|
||
| if self.context.client_metadata.scope: | ||
| token_data["scope"] = self.context.client_metadata.scope | ||
|
|
||
| token_url = self._get_token_endpoint() | ||
| return httpx.Request( | ||
| "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add the right imports here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you mean all classes one by one?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed in #1532