
feat: Add OAuth client-side middleware with token injection and auto-refresh #83

@guyernest

Problem

The SDK has comprehensive server-side OAuth support (bearer token validation, scope middleware) but lacks client-side OAuth middleware for:

  • Automatic token injection into outgoing requests
  • 401/403 detection and automatic token refresh
  • Retry logic after re-authentication
  • Token lifecycle management (expiry tracking, proactive refresh)

This is critical for MCP clients connecting to OAuth-protected servers. The TypeScript SDK provides withOAuth for this use case.

Proposed Solution

Add client-side OAuth middleware that operates at the HTTP transport layer:

```rust
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use async_trait::async_trait;
use tokio::sync::RwLock;
use url::Url;

pub struct OAuthClientMiddleware {
    config: OAuthConfig,
    token_store: Arc<RwLock<TokenStore>>,
    auth_flow: Arc<dyn OAuthFlow>,
}

pub struct OAuthConfig {
    /// OAuth provider configuration
    pub provider: OAuthProvider,

    /// Token storage strategy
    pub token_storage: TokenStorageStrategy,

    /// Refresh proactively when the token is this close to expiry
    pub refresh_before_expiry: Duration,

    /// Refresh and retry after a 401/403 response
    pub retry_on_auth_failure: bool,

    /// Maximum retry attempts
    pub max_retry_attempts: u32,
}

/// Where tokens are kept between requests (and, for `Persistent`, between runs)
pub enum TokenStorageStrategy {
    /// In-process only; tokens are lost on restart
    Memory,
    /// Tokens written to a file on disk
    Persistent { path: PathBuf },
}

pub enum OAuthProvider {
    /// Standard OAuth2 with client credentials
    ClientCredentials {
        token_url: Url,
        client_id: String,
        client_secret: String,
        scopes: Vec<String>,
    },

    /// OAuth2 with refresh token
    RefreshToken {
        token_url: Url,
        client_id: String,
        client_secret: String,
        refresh_token: String,
    },

    /// Custom token provider (for exotic flows)
    Custom(Arc<dyn TokenProvider>),
}

#[async_trait]
pub trait TokenProvider: Send + Sync {
    /// Obtain a fresh access token
    async fn get_token(&self) -> Result<AccessToken>;

    /// Refresh an expired token
    async fn refresh_token(&self, current: &AccessToken) -> Result<AccessToken>;
}

#[derive(Clone, Debug)]
pub struct AccessToken {
    pub token: String,
    pub token_type: String,
    pub expires_at: Option<SystemTime>,
    pub scopes: Vec<String>,
}
```
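The expiry bookkeeping behind `refresh_before_expiry` is small enough to sketch in isolation. This is a minimal sketch under stated assumptions. The helper names `from_response`, `needs_refresh` and `authorization_value` are hypothetical, and `AccessToken` is trimmed (no `scopes`).

The sketch converts the token endpoint's relative `expires_in` (RFC 6749 §5.1) into an absolute deadline. It then treats a token as due for refresh once `now + skew` reaches that deadline. RFC 6749 defines `token_type` as case-insensitive, so the sketch also normalizes a server's `bearer` to the canonical `Bearer` scheme before it is put into the header.

```rust
use std::time::{Duration, SystemTime};

/// Hypothetical, trimmed mirror of the proposed `AccessToken`.
#[derive(Clone, Debug)]
pub struct AccessToken {
    pub token: String,
    pub token_type: String,
    pub expires_at: Option<SystemTime>,
}

impl AccessToken {
    /// Build a token from a token-endpoint response, turning the relative
    /// `expires_in` (seconds) into an absolute deadline measured from `now`.
    pub fn from_response(
        token: &str,
        token_type: &str,
        expires_in: Option<u64>,
        now: SystemTime,
    ) -> Self {
        AccessToken {
            token: token.to_string(),
            token_type: token_type.to_string(),
            expires_at: expires_in.map(|secs| now + Duration::from_secs(secs)),
        }
    }

    /// True once the token is expired or will expire within `skew`
    /// (the `refresh_before_expiry` window). Tokens without an expiry
    /// are never refreshed proactively; only a 401 would trigger that.
    pub fn needs_refresh(&self, skew: Duration, now: SystemTime) -> bool {
        match self.expires_at {
            Some(deadline) => now + skew >= deadline,
            None => false,
        }
    }

    /// Value for the `Authorization` header, normalizing any casing of
    /// "bearer" to "Bearer" and passing other token types through unchanged.
    pub fn authorization_value(&self) -> String {
        let scheme = if self.token_type.eq_ignore_ascii_case("bearer") {
            "Bearer"
        } else {
            self.token_type.as_str()
        };
        format!("{} {}", scheme, self.token)
    }
}
```

Passing `now` in explicitly, rather than calling `SystemTime::now()` inside, keeps the check deterministic in unit tests.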

Implementation of the HttpMiddleware trait (proposed in #82):

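```rust
use async_trait::async_trait;
use http::{header::AUTHORIZATION, HeaderValue, StatusCode};

#[async_trait]
impl HttpMiddleware for OAuthClientMiddleware {
    fn priority(&self) -> MiddlewarePriority {
        MiddlewarePriority::High // Run early so later middleware sees the Authorization header
    }

    async fn on_request(
        &self,
        request: &mut http::Request<Vec<u8>>,
        _context: &HttpMiddlewareContext,
    ) -> Result<()> {
        // Proactive refresh: renew when within `refresh_before_expiry` of expiry.
        // Concurrent requests can all reach this branch, so refresh_token() should be single-flight.
        if self.should_refresh_token().await? {
            self.refresh_token().await?;
        }

        // Inject the current token (the read guard is a temporary, released at the end of the statement)
        let token = self.token_store.read().await.current_token()?;
        request.headers_mut().insert(
            AUTHORIZATION,
            HeaderValue::from_str(&format!("{} {}", token.token_type, token.token))?,
        );

        Ok(())
    }

    async fn on_response(
        &self,
        response: &mut http::Response<Vec<u8>>,
        context: &HttpMiddlewareContext,
    ) -> Result<()> {
        // 401 usually means an expired or revoked token. A 403 (e.g. `insufficient_scope`)
        // is often not fixed by a refresh; max_retry_attempts bounds the wasted round-trips.
        let auth_failed = response.status() == StatusCode::UNAUTHORIZED
            || response.status() == StatusCode::FORBIDDEN;

        if auth_failed
            && self.config.retry_on_auth_failure
            && context.attempt < self.config.max_retry_attempts
        {
            self.refresh_token().await?;

            // Signal for retry. Assumes the #82 context allows setting metadata through `&self`
            // (interior mutability) and that the transport replays the buffered request on this flag.
            context.set_metadata("should_retry", "true");
            return Err(Error::AuthenticationFailed(
                "Authentication rejected; token refreshed, request should be retried".into(),
            ));
        }

        Ok(())
    }
}
```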
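`on_response` can only *signal* a retry; something in the transport still has to loop. The following is a hedged, synchronous sketch of that bounded loop. HTTP is reduced to status codes, and the names `send_with_auth_retry` and `Outcome` are hypothetical.

The sketch counts `max_retry_attempts` as retries after the first send, which matches the `context.attempt < max_retry_attempts` check. Replaying a request is only possible because the proposal buffers bodies as `Vec<u8>`.

```rust
/// Hypothetical result of the retry loop, reduced to the final HTTP status.
#[derive(Debug, PartialEq)]
pub enum Outcome {
    /// Server returned a non-auth status (success or any other error)
    Completed(u16),
    /// Still 401/403 after retries were exhausted, disabled, or refresh failed
    AuthFailed(u16),
}

/// Send a request; on 401/403 refresh the token and resend, at most
/// `max_retry_attempts` extra times.
pub fn send_with_auth_retry<S, R>(
    mut send: S,
    mut refresh: R,
    retry_on_auth_failure: bool,
    max_retry_attempts: u32,
) -> Outcome
where
    S: FnMut(u32) -> u16,              // attempt number -> HTTP status
    R: FnMut() -> Result<(), String>,  // obtain a new token
{
    let mut attempt = 0;
    loop {
        let status = send(attempt);
        if status != 401 && status != 403 {
            return Outcome::Completed(status);
        }
        if !retry_on_auth_failure || attempt >= max_retry_attempts {
            return Outcome::AuthFailed(status);
        }
        // A failed refresh ends the loop instead of replaying with a stale token.
        if refresh().is_err() {
            return Outcome::AuthFailed(status);
        }
        attempt += 1;
    }
}
```

With `max_retry_attempts: 2` and a server that keeps answering 401, the request goes out three times before the caller sees the failure.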

Use Cases

1. Client Credentials Flow

```rust
use std::{env, time::Duration};

use pmcp::client::oauth::{OAuthClientMiddleware, OAuthConfig, OAuthProvider, TokenStorageStrategy};
use url::Url;

let oauth_config = OAuthConfig {
    provider: OAuthProvider::ClientCredentials {
        token_url: Url::parse("https://auth.example.com/token")?,
        client_id: "my-client".to_string(),
        client_secret: env::var("CLIENT_SECRET")?,
        scopes: vec!["mcp:read".to_string(), "mcp:tools".to_string()],
    },
    token_storage: TokenStorageStrategy::Memory,
    refresh_before_expiry: Duration::from_secs(60),
    retry_on_auth_failure: true,
    max_retry_attempts: 3,
};

let oauth_middleware = OAuthClientMiddleware::new(oauth_config).await?;

let transport = HttpTransport::builder()
    .url(server_url)
    .middleware(oauth_middleware)  // Auto token injection + refresh
    .build()?;

let mut client = Client::new(transport);
let info = client.initialize(ClientCapabilities::minimal()).await?;
// Token is injected automatically; on 401/403 it is refreshed and the request retried
```
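Under the hood, a `ClientCredentials` provider would POST a form body to `token_url`. Below is a std-only sketch of that request under stated assumptions. The helpers `form_encode`, `base64`, `client_basic_auth` and `client_credentials_body` are hypothetical; a real implementation would more likely lean on existing crates for encoding.

The body follows RFC 6749 §4.4.2, with space-delimited scopes per §3.3. Client authentication uses HTTP Basic as recommended in §2.3.1. In that scheme, the id and secret are form-encoded before they are joined and base64-encoded.

```rust
/// Percent-encode for an `application/x-www-form-urlencoded` body
/// (WHATWG rules: ASCII alphanumerics and `*-._` pass through, space becomes `+`).
pub fn form_encode(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for byte in value.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'*' | b'-' | b'.' | b'_' => {
                out.push(byte as char)
            }
            b' ' => out.push('+'),
            _ => out.push_str(&format!("%{:02X}", byte)),
        }
    }
    out
}

/// Standard base64 alphabet with `=` padding (RFC 4648).
pub fn base64(input: &[u8]) -> String {
    const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut out = String::with_capacity((input.len() + 2) / 3 * 4);
    for chunk in input.chunks(3) {
        let b1 = chunk.get(1).copied().unwrap_or(0);
        let b2 = chunk.get(2).copied().unwrap_or(0);
        let n = (u32::from(chunk[0]) << 16) | (u32::from(b1) << 8) | u32::from(b2);
        out.push(TABLE[(n >> 18) as usize & 63] as char);
        out.push(TABLE[(n >> 12) as usize & 63] as char);
        out.push(if chunk.len() > 1 { TABLE[(n >> 6) as usize & 63] as char } else { '=' });
        out.push(if chunk.len() > 2 { TABLE[n as usize & 63] as char } else { '=' });
    }
    out
}

/// `Authorization` header value for client authentication at the token endpoint.
pub fn client_basic_auth(client_id: &str, client_secret: &str) -> String {
    let credentials = format!("{}:{}", form_encode(client_id), form_encode(client_secret));
    format!("Basic {}", base64(credentials.as_bytes()))
}

/// Form body for a client-credentials token request.
pub fn client_credentials_body(scopes: &[&str]) -> String {
    let mut body = String::from("grant_type=client_credentials");
    if !scopes.is_empty() {
        body.push_str("&scope=");
        body.push_str(&form_encode(&scopes.join(" ")));
    }
    body
}
```

For the configuration above, `client_credentials_body(&["mcp:read", "mcp:tools"])` yields `grant_type=client_credentials&scope=mcp%3Aread+mcp%3Atools`.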

2. Refresh Token Flow

```rust
let oauth_config = OAuthConfig {
    provider: OAuthProvider::RefreshToken {
        token_url: Url::parse("https://auth.example.com/token")?,
        client_id: "my-client".to_string(),
        client_secret: env::var("CLIENT_SECRET")?,
        refresh_token: load_refresh_token()?,
    },
    token_storage: TokenStorageStrategy::Persistent {
        path: PathBuf::from(".oauth_tokens"),
    },
    refresh_before_expiry: Duration::from_secs(300),
    retry_on_auth_failure: true,
    max_retry_attempts: 2,
};

let oauth_middleware = OAuthClientMiddleware::new(oauth_config).await?;
```
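Refresh responses matter for `TokenStorageStrategy::Persistent`. RFC 6749 §6 lets a server issue a new refresh token, in which case the client must discard the old one. If the rotated token is not written back to `.oauth_tokens`, the next process start would present a refresh token that the server may already have invalidated.

A minimal std-only sketch of that write-back follows. `TokenResponse` and `apply_refresh_response` are hypothetical names, and restricting the file's permissions (e.g. 0600) is left out because it is platform-specific.

```rust
use std::fs;
use std::io;
use std::path::Path;

/// Hypothetical subset of a token-endpoint response (RFC 6749 §5.1).
pub struct TokenResponse {
    pub access_token: String,
    /// Present only when the server issues a new (rotated) refresh token.
    pub refresh_token: Option<String>,
}

/// Adopt and persist a rotated refresh token, then return the new access token.
pub fn apply_refresh_response(
    current_refresh: &mut String,
    response: TokenResponse,
    path: &Path,
) -> io::Result<String> {
    if let Some(rotated) = response.refresh_token {
        if rotated != *current_refresh {
            // Write a sibling temp file and rename it over the old one, so a crash
            // mid-write does not leave a truncated token file behind.
            let tmp = path.with_extension("tmp");
            fs::write(&tmp, &rotated)?;
            fs::rename(&tmp, path)?;
            *current_refresh = rotated;
        }
    }
    Ok(response.access_token)
}
```

The in-memory value is only updated after the file write succeeds, so memory and disk do not diverge if persisting fails.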

3. Custom Token Provider

```rust
use std::sync::Arc;

use async_trait::async_trait;

struct CustomTokenProvider {
    // Your custom auth logic
}

impl CustomTokenProvider {
    fn new() -> Self {
        Self {}
    }
}

#[async_trait]
impl TokenProvider for CustomTokenProvider {
    async fn get_token(&self) -> Result<AccessToken> {
        // Custom token acquisition (e.g., OAuth 2.0 device authorization grant)
        todo!()
    }

    async fn refresh_token(&self, _current: &AccessToken) -> Result<AccessToken> {
        // Custom refresh logic
        todo!()
    }
}

let oauth_config = OAuthConfig {
    provider: OAuthProvider::Custom(Arc::new(CustomTokenProvider::new())),
    // ...
};
```

Implementation Plan

  1. Phase 1: Core OAuth types (Dependency: Issue #82, "feat: Add HttpMiddleware trait for transport-level HTTP concerns")

    • Define OAuthClientMiddleware struct
    • Add OAuthConfig, OAuthProvider, AccessToken
    • Implement TokenStore with memory/persistent storage
  2. Phase 2: Token providers

    • ClientCredentialsProvider - Client credentials flow
    • RefreshTokenProvider - Refresh token flow
    • TokenProvider trait for custom implementations
  3. Phase 3: Middleware implementation

    • Implement HttpMiddleware for OAuthClientMiddleware
    • Token injection in on_request
    • 401/403 detection and retry in on_response
    • Proactive token refresh logic
  4. Phase 4: Token lifecycle

    • Expiry tracking and proactive refresh
    • Token caching and persistence
    • Thread-safe token store
  5. Phase 5: Testing and examples

    • Unit tests for token flows
    • Integration tests with mock OAuth server
    • Example: examples/31_oauth_client.rs
    • Documentation in pmcp-book
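Phase 4's "thread-safe token store" also needs to handle concurrent refreshes. When several in-flight requests hit the same 401, each would otherwise refresh independently. With rotating refresh tokens, those parallel refreshes can invalidate each other.

Below is a hedged sketch of single-flight refresh using a generation counter. `TokenStore`, `current` and `refresh_if_stale` are hypothetical names. The sketch uses `std::sync` locks to stay self-contained. The proposal's async store would instead hold a `tokio::sync::Mutex` across the awaited token fetch.

```rust
use std::sync::{Mutex, RwLock};

/// Hypothetical in-memory token store. The generation is bumped on every
/// refresh, so a caller can tell whether the token it sent was already replaced.
pub struct TokenStore {
    current: RwLock<(u64, String)>, // (generation, access token)
    refresh_lock: Mutex<()>,
}

impl TokenStore {
    pub fn new(initial: String) -> Self {
        TokenStore {
            current: RwLock::new((0, initial)),
            refresh_lock: Mutex::new(()),
        }
    }

    /// Snapshot of (generation, token) to attach to an outgoing request.
    pub fn current(&self) -> (u64, String) {
        self.current.read().unwrap().clone()
    }

    /// Refresh only if `seen_generation` is still current. Callers that raced on
    /// the same 401 queue on `refresh_lock`; the first calls `fetch`, the rest see
    /// a newer generation and reuse that token without calling `fetch`.
    pub fn refresh_if_stale<F>(&self, seen_generation: u64, fetch: F) -> Result<(u64, String), String>
    where
        F: FnOnce() -> Result<String, String>,
    {
        // unwrap(): a poisoned lock means another thread panicked mid-update.
        let _guard = self.refresh_lock.lock().unwrap();
        {
            let snapshot = self.current.read().unwrap();
            if snapshot.0 != seen_generation {
                return Ok(snapshot.clone());
            }
        }
        let fresh = fetch()?;
        let mut slot = self.current.write().unwrap();
        let next = slot.0 + 1;
        *slot = (next, fresh);
        Ok(slot.clone())
    }
}
```

Readers never wait on the network call: `current()` only takes the read lock, and the write lock is held just long enough to swap in the new token.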

Benefits

  • Automatic auth handling: No manual token management in user code
  • Resilient: Automatic refresh on 401/403 with retry
  • TypeScript parity: Matches TS SDK's withOAuth ergonomics
  • Flexible: Supports multiple OAuth flows + custom providers
  • Production-ready: Token persistence, proactive refresh, thread-safe

References

Related Issues


Priority: High (critical for production OAuth deployments)
Complexity: Medium-High
Dependencies: Issue #82 (HttpMiddleware)

Metadata

Assignees

No one assigned

    Labels

    enhancement: New feature or request

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests