Skip to content

Extend SSETransport to refresh bearer token #904

Closed
@JasperHG90

Description

@JasperHG90

Is your feature request related to a problem? Please describe.
I have deployed an MCP server to Cloud Run (GCP) and use the FastMCP proxy approach to access it locally. The problem is that the OIDC ID token I fetch to authenticate is short-lived, so passing a static bearer token isn't an option.

Describe the solution you'd like
The simplest solution would be to extend the `headers` argument of the `SSETransport` class so that it also accepts a callable (a header factory) that is invoked each time a connection is opened. I'd be happy to open a PR with this change.

I'm currently using this custom transport:

import contextlib
import datetime
from collections.abc import AsyncIterator
from typing import Any, cast, Callable

from mcp import ClientSession
from mcp.client.sse import sse_client
from pydantic import AnyUrl
from typing_extensions import Unpack

from fastmcp.client.transports import ClientTransport, SessionKwargs
from fastmcp.server.dependencies import get_http_headers


class SSETransport(ClientTransport):
    """Transport implementation that connects to an MCP server via Server-Sent Events.

    ``headers`` may be a static dict or a zero-argument factory returning a
    dict. A factory is called every time :meth:`connect_session` opens a new
    connection, which lets callers supply short-lived bearer tokens (for
    example OIDC ID tokens) that are refreshed between connections. Headers
    are resolved once per connection; they are not recomputed while that
    connection stays open.
    """

    def __init__(
        self,
        url: str | AnyUrl,
        headers: Callable[[], dict[str, str]] | dict[str, str] | None = None,
        sse_read_timeout: datetime.timedelta | float | int | None = None,
    ):
        """Create the transport.

        Args:
            url: SSE endpoint; must start with ``http://`` or ``https://``.
            headers: Static headers, or a zero-argument factory that is called
                on every connect to produce the headers.
            sse_read_timeout: Read timeout for the SSE stream. Plain numbers are
                interpreted as seconds. ``None`` leaves ``sse_client``'s own
                default in place.

        Raises:
            ValueError: If ``url`` is not an HTTP(S) URL.
        """
        if isinstance(url, AnyUrl):
            url = str(url)
        # Require a full scheme prefix: checking only for 'http' would also
        # accept strings such as 'httpfoo' that are not HTTP(S) URLs.
        if not isinstance(url, str) or not url.startswith(('http://', 'https://')):
            raise ValueError('Invalid HTTP/S URL provided for SSE.')
        self.url = url
        self.headers = headers or {}

        if isinstance(sse_read_timeout, int | float):
            sse_read_timeout = datetime.timedelta(seconds=sse_read_timeout)
        self.sse_read_timeout = sse_read_timeout

    def _resolve_headers(self) -> dict[str, str]:
        """Return the static headers, or call the factory to obtain fresh ones."""
        if callable(self.headers):
            # A factory that returns None is treated as "no extra headers"
            # instead of failing later on the dict merge.
            return self.headers() or {}
        return self.headers

    @contextlib.asynccontextmanager
    async def connect_session(
        self, **session_kwargs: Unpack[SessionKwargs]
    ) -> AsyncIterator[ClientSession]:
        """Open an SSE connection and yield a ``ClientSession`` bound to it.

        Args:
            **session_kwargs: Forwarded to ``ClientSession``. If
                ``read_timeout_seconds`` (a ``timedelta``) is set, it is also
                passed to ``sse_client`` as ``timeout`` in seconds.

        Yields:
            The ``ClientSession`` for this connection. The connection is
            closed when the context exits.
        """
        client_kwargs: dict[str, Any] = {}

        # get_http_headers() supplies headers from an active incoming HTTP
        # request. Per the original design this is only the case when the
        # client backs a FastMCP proxy, so the MCP client's headers get
        # forwarded. This transport's own headers are merged on the right, so
        # they take precedence on conflicting keys.
        client_kwargs['headers'] = get_http_headers() | self._resolve_headers()

        # sse_read_timeout has a default value set in sse_client, so we can't
        # pass None without overriding it; instead leave the kwarg out.
        if self.sse_read_timeout is not None:
            client_kwargs['sse_read_timeout'] = self.sse_read_timeout.total_seconds()

        read_timeout_seconds = session_kwargs.get('read_timeout_seconds')
        if read_timeout_seconds is not None:
            client_kwargs['timeout'] = cast(
                datetime.timedelta, read_timeout_seconds
            ).total_seconds()

        async with sse_client(self.url, **client_kwargs) as transport:
            read_stream, write_stream = transport
            async with ClientSession(read_stream, write_stream, **session_kwargs) as session:
                yield session

    def __repr__(self) -> str:
        return f"<SSE(url='{self.url}')>"

Then, in `mcp_proxy.py`:

import os

from fastmcp.client import Client
import httpx
from cachetools import TTLCache, cached
import google.auth
from fastmcp import FastMCP

# This SSETransport is customized to allow for short-lived
# id tokens, caching them, and refreshing them as needed.
from mcp_contract_registry.transports import SSETransport


# Google-issued ID tokens are assumed to live for one hour (3600 s). Expire the
# cached token a few minutes early so it is never sent right at (or past) its
# expiry once request latency and clock skew are accounted for.
@cached(TTLCache(maxsize=1, ttl=3300))
def get_id_token() -> str:
    """Exchange the Application Default Credentials refresh token for an ID token.

    The result is cached (single entry) so repeated header-factory calls reuse
    the same token until the cache entry expires.

    Returns:
        The ``id_token`` field of Google's OAuth 2.0 token endpoint response.

    Raises:
        httpx.HTTPStatusError: If the token endpoint returns an error status.
        KeyError: If the response JSON has no ``id_token`` field.
    """
    token_url = 'https://oauth2.googleapis.com/token'
    creds = google.auth.default()[0]
    # NOTE(review): _client_id / _client_secret / refresh_token are present on
    # user ("authorized_user") credentials, e.g. from
    # `gcloud auth application-default login`. Presumably this fails for
    # service-account credentials; confirm before relying on it elsewhere.
    form = {
        'grant_type': 'refresh_token',
        'client_id': creds._client_id,  # type: ignore
        'client_secret': creds._client_secret,  # type: ignore
        'refresh_token': creds.refresh_token,  # type: ignore
    }
    # Send the secrets form-encoded in the request body (RFC 6749 §6), not in
    # the query string. URLs are routinely logged by proxies and servers, which
    # would leak the client secret and refresh token.
    response = httpx.post(url=token_url, data=form)
    response.raise_for_status()
    return response.json()['id_token']


def get_auth_header() -> dict[str, str]:
    """Build the ``Authorization`` header from the (cached) Google ID token.

    Passed as the ``headers`` factory of ``SSETransport`` in ``entrypoint``.
    """
    token = get_id_token()
    header_value = f'Bearer {token}'
    return {'Authorization': header_value}


def entrypoint():
    """Run a local FastMCP proxy in front of the remote MCP contract registry.

    The backend URL is read from the ``MCP_CONTRACT_REGISTRY_URL`` environment
    variable (``KeyError`` if unset), and ``get_auth_header`` is handed to the
    transport as its header factory.
    """
    backend_url = os.environ['MCP_CONTRACT_REGISTRY_URL']
    transport = SSETransport(backend_url, headers=get_auth_header)
    backend_client = Client(transport)
    FastMCP.as_proxy(backend=backend_client).run()

Describe alternatives you've considered
N/A

Additional context
N/A

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions