Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ use std::{
use async_trait::async_trait;
use oauth2::{
AsyncHttpClient, AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
EmptyExtraTokenFields, HttpClientError, HttpRequest, HttpResponse, PkceCodeChallenge,
PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope, StandardTokenResponse,
TokenResponse, TokenUrl,
basic::{BasicClient, BasicTokenType},
EmptyExtraTokenFields, ExtraTokenFields, HttpClientError, HttpRequest, HttpResponse,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
StandardTokenResponse, TokenResponse, TokenUrl, basic::BasicTokenType,
};
use reqwest::{
Client as HttpClient, IntoUrl, StatusCode, Url,
header::{AUTHORIZATION, WWW_AUTHENTICATE},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error, warn};
Expand Down Expand Up @@ -126,6 +126,32 @@ pub struct StoredAuthorizationState {
pub created_at: u64,
}

/// A transparent wrapper around a JSON object that captures any extra fields returned by the
/// authorization server during token exchange that are not part of the standard OAuth 2.0 token
/// response.
///
/// OAuth providers may include non-standard fields alongside the
/// standard OAuth fields. Those fields are collected here so callers
/// can inspect them without losing data.
///
/// The inner [`HashMap<String, Value>`] maps field names to their raw JSON values.
///
/// # Accessing extra fields
///
/// Extra fields are available through [`StandardTokenResponse::extra_fields()`], which returns a
/// reference to this struct. Use the inner map (`.0`) to look up individual fields by name:
///
/// ```rust,ignore
/// // Obtain the token response from the AuthorizationManager, then:
/// if let Some(value) = token_response.extra_fields().0.get("vendorSpecificField") {
/// println!("vendorSpecificField = {value}");
/// }
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct VendorExtraTokenFields(pub HashMap<String, Value>);

impl ExtraTokenFields for VendorExtraTokenFields {}

impl StoredAuthorizationState {
pub fn new(pkce_verifier: &PkceCodeVerifier, csrf_token: &CsrfToken) -> Self {
Self {
Expand Down Expand Up @@ -345,7 +371,18 @@ pub struct OAuthClientConfig {

// add type aliases for oauth2 types
type OAuthErrorResponse = oauth2::StandardErrorResponse<oauth2::basic::BasicErrorResponseType>;
pub type OAuthTokenResponse = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;

/// The token response returned by the authorization server after a successful OAuth 2.0 flow.
///
/// This is a [`StandardTokenResponse`] parameterised with [`VendorExtraTokenFields`], which means
/// it carries both the standard OAuth fields and
/// any vendor-specific fields the server may have included in the JSON response body.
///
/// # Accessing vendor-specific fields
///
/// Call [`extra_fields()`][OAuthTokenResponse::extra_fields] to obtain a reference to the
/// [`VendorExtraTokenFields`] wrapper, then index into its inner map.
pub type OAuthTokenResponse = StandardTokenResponse<VendorExtraTokenFields, BasicTokenType>;
type OAuthTokenIntrospection =
oauth2::StandardTokenIntrospectionResponse<EmptyExtraTokenFields, BasicTokenType>;
type OAuthRevocableToken = oauth2::StandardRevocableToken;
Expand Down Expand Up @@ -581,7 +618,7 @@ impl AuthorizationManager {
let redirect_url = RedirectUrl::new(config.redirect_uri.clone())
.map_err(|e| AuthError::OAuthError(format!("Invalid re URL: {}", e)))?;

let mut client_builder = BasicClient::new(client_id.clone())
let mut client_builder: OAuthClient = oauth2::Client::new(client_id.clone())
.set_auth_uri(auth_url)
.set_token_uri(token_url)
.set_redirect_uri(redirect_url);
Expand Down Expand Up @@ -882,7 +919,7 @@ impl AuthorizationManager {
&self,
code: &str,
csrf_token: &str,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>, AuthError> {
) -> Result<OAuthTokenResponse, AuthError> {
debug!("start exchange code for token: {:?}", code);
let oauth_client = self
.oauth_client
Expand Down Expand Up @@ -1017,9 +1054,7 @@ impl AuthorizationManager {
}

/// refresh access token
pub async fn refresh_token(
&self,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>, AuthError> {
pub async fn refresh_token(&self) -> Result<OAuthTokenResponse, AuthError> {
let oauth_client = self
.oauth_client
.as_ref()
Expand Down Expand Up @@ -1551,7 +1586,7 @@ impl AuthorizationSession {
&self,
code: &str,
csrf_token: &str,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>, AuthError> {
) -> Result<OAuthTokenResponse, AuthError> {
self.auth_manager
.exchange_code_for_token(code, csrf_token)
.await
Expand Down Expand Up @@ -1876,6 +1911,7 @@ mod tests {
AuthError, AuthorizationManager, AuthorizationMetadata, InMemoryStateStore,
OAuthClientConfig, ScopeUpgradeConfig, StateStore, StoredAuthorizationState, is_https_url,
};
use crate::transport::auth::VendorExtraTokenFields;

// -- url helpers --

Expand Down Expand Up @@ -2686,11 +2722,13 @@ mod tests {
use super::{OAuthTokenResponse, StoredCredentials};

fn make_token_response(access_token: &str, expires_in_secs: Option<u64>) -> OAuthTokenResponse {
use oauth2::{AccessToken, EmptyExtraTokenFields, basic::BasicTokenType};
use oauth2::{AccessToken, basic::BasicTokenType};
let mut resp = OAuthTokenResponse::new(
AccessToken::new(access_token.to_string()),
BasicTokenType::Bearer,
EmptyExtraTokenFields {},
VendorExtraTokenFields {
..Default::default()
},
);
if let Some(secs) = expires_in_secs {
resp.set_expires_in(Some(&std::time::Duration::from_secs(secs)));
Expand Down