Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,10 @@ public sealed class ClientOAuthOptions
/// </para>
/// </remarks>
public IDictionary<string, string> AdditionalAuthorizationParameters { get; set; } = new Dictionary<string, string>();

/// <summary>
/// Gets or sets the token cache to use for storing and retrieving tokens beyond the lifetime of the transport.
/// If none is provided, tokens will be cached with the transport.
/// </summary>
public ITokenCache? TokenCache { get; set; }
}
103 changes: 66 additions & 37 deletions src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ internal sealed partial class ClientOAuthProvider
/// </summary>
private const string BearerScheme = "Bearer";

private static readonly string[] s_wellKnownPaths = [".well-known/openid-configuration", ".well-known/oauth-authorization-server"];

private readonly Uri _serverUrl;
private readonly Uri _redirectUri;
private readonly string[]? _scopes;
Expand All @@ -43,7 +45,7 @@ internal sealed partial class ClientOAuthProvider
private string? _clientId;
private string? _clientSecret;

private TokenContainer? _token;
private ITokenCache _tokenCache;
private AuthorizationServerMetadata? _authServerMetadata;

/// <summary>
Expand All @@ -57,11 +59,11 @@ internal sealed partial class ClientOAuthProvider
public ClientOAuthProvider(
Uri serverUrl,
ClientOAuthOptions options,
HttpClient? httpClient = null,
HttpClient httpClient,
ILoggerFactory? loggerFactory = null)
{
_serverUrl = serverUrl ?? throw new ArgumentNullException(nameof(serverUrl));
_httpClient = httpClient ?? new HttpClient();
_httpClient = httpClient;
_logger = (ILogger?)loggerFactory?.CreateLogger<ClientOAuthProvider>() ?? NullLogger.Instance;

if (options is null)
Expand All @@ -85,6 +87,7 @@ public ClientOAuthProvider(
_dcrClientUri = options.DynamicClientRegistration?.ClientUri;
_dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken;
_dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate;
_tokenCache = options.TokenCache ?? new InMemoryTokenCache();
}

/// <summary>
Expand Down Expand Up @@ -138,20 +141,21 @@ public ClientOAuthProvider(
{
ThrowIfNotBearerScheme(scheme);

var tokens = await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false);

// Return the token if it's valid
if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5))
if (tokens is not null && !tokens.IsExpired)
{
return _token.AccessToken;
return tokens.AccessToken;
}

// Try to refresh the token if we have a refresh token
if (_token?.RefreshToken != null && _authServerMetadata != null)
// Try to refresh the access token if it is invalid and we have a refresh token.
if (tokens?.RefreshToken != null && _authServerMetadata != null)
{
var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false);
if (newToken != null)
var newTokens = await RefreshTokenAsync(tokens.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false);
if (newTokens is not null)
{
_token = newToken;
return _token.AccessToken;
return newTokens.AccessToken;
}
}

Expand Down Expand Up @@ -223,26 +227,29 @@ private async Task PerformOAuthAuthorizationAsync(
// Store auth server metadata for future refresh operations
_authServerMetadata = authServerMetadata;

// The existing access token must be invalid to have resulted in a 401 response, but refresh might still work.
if (await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false) is { RefreshToken: {} refreshToken })
{
var refreshedTokens = await RefreshTokenAsync(refreshToken, protectedResourceMetadata.Resource, authServerMetadata, cancellationToken).ConfigureAwait(false);
if (refreshedTokens is not null)
{
// A non-null result indicates the refresh succeeded and the new tokens have been stored.
return;
}
}

// Perform dynamic client registration if needed
if (string.IsNullOrEmpty(_clientId))
{
await PerformDynamicClientRegistrationAsync(authServerMetadata, cancellationToken).ConfigureAwait(false);
}

// Perform the OAuth flow
var token = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false);
await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false);

if (token is null)
{
ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token.");
}

_token = token;
LogOAuthAuthorizationCompleted();
}

private static readonly string[] s_wellKnownPaths = [".well-known/openid-configuration", ".well-known/oauth-authorization-server"];

private async Task<AuthorizationServerMetadata> GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken)
{
if (authServerUri.OriginalString.Length == 0 ||
Expand Down Expand Up @@ -298,7 +305,7 @@ private async Task<AuthorizationServerMetadata> GetAuthServerMetadataAsync(Uri a
throw new McpException($"Failed to find .well-known/openid-configuration or .well-known/oauth-authorization-server metadata for authorization server: '{authServerUri}'");
}

private async Task<TokenContainer> RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken)
private async Task<TokenContainer?> RefreshTokenAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken)
{
var requestContent = new FormUrlEncodedContent(new Dictionary<string, string>
{
Expand All @@ -314,10 +321,17 @@ private async Task<TokenContainer> RefreshTokenAsync(string refreshToken, Uri re
Content = requestContent
};

return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false);
using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);

if (!httpResponse.IsSuccessStatusCode)
{
return null;
}

return await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
}

private async Task<TokenContainer?> InitiateAuthorizationCodeFlowAsync(
private async Task InitiateAuthorizationCodeFlowAsync(
ProtectedResourceMetadata protectedResourceMetadata,
AuthorizationServerMetadata authServerMetadata,
CancellationToken cancellationToken)
Expand All @@ -330,10 +344,10 @@ private async Task<TokenContainer> RefreshTokenAsync(string refreshToken, Uri re

if (string.IsNullOrEmpty(authCode))
{
return null;
ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty authorization code.");
}

return await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false);
await ExchangeCodeForTokenAsync(protectedResourceMetadata, authServerMetadata, authCode!, codeVerifier, cancellationToken).ConfigureAwait(false);
}

private Uri BuildAuthorizationUrl(
Expand Down Expand Up @@ -377,7 +391,7 @@ private Uri BuildAuthorizationUrl(
return uriBuilder.Uri;
}

private async Task<TokenContainer> ExchangeCodeForTokenAsync(
private async Task ExchangeCodeForTokenAsync(
ProtectedResourceMetadata protectedResourceMetadata,
AuthorizationServerMetadata authServerMetadata,
string authorizationCode,
Expand All @@ -400,24 +414,39 @@ private async Task<TokenContainer> ExchangeCodeForTokenAsync(
Content = requestContent
};

return await FetchTokenAsync(request, cancellationToken).ConfigureAwait(false);
}

private async Task<TokenContainer> FetchTokenAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);
httpResponse.EnsureSuccessStatusCode();
await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
}

using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenContainer, cancellationToken).ConfigureAwait(false);
private async Task<TokenContainer> HandleSuccessfulTokenResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken)
{
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenResponse, cancellationToken).ConfigureAwait(false);

if (tokenResponse is null)
{
ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response.");
ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{response.RequestMessage?.RequestUri}' returned an empty response.");
}

if (tokenResponse.TokenType is null || !string.Equals(tokenResponse.TokenType, BearerScheme, StringComparison.OrdinalIgnoreCase))
{
ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{response.RequestMessage?.RequestUri}' returned an unsupported token type: '{tokenResponse.TokenType ?? "<null>"}'. Only 'Bearer' tokens are supported.");
}

tokenResponse.ObtainedAt = DateTimeOffset.UtcNow;
return tokenResponse;
TokenContainer tokens = new()
{
AccessToken = tokenResponse.AccessToken,
RefreshToken = tokenResponse.RefreshToken,
ExpiresIn = tokenResponse.ExpiresIn,
TokenType = tokenResponse.TokenType,
Scope = tokenResponse.Scope,
ObtainedAt = DateTimeOffset.UtcNow,
};

await _tokenCache.StoreTokensAsync(tokens, cancellationToken).ConfigureAwait(false);

return tokens;
}

/// <summary>
Expand Down Expand Up @@ -581,7 +610,7 @@ private async Task<ProtectedResourceMetadata> ExtractProtectedResourceMetadata(H
string? resourceMetadataUrl = null;
foreach (var header in response.Headers.WwwAuthenticate)
{
if (string.Equals(header.Scheme, "Bearer", StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter))
if (string.Equals(header.Scheme, BearerScheme, StringComparison.OrdinalIgnoreCase) && !string.IsNullOrEmpty(header.Parameter))
{
resourceMetadataUrl = ParseWwwAuthenticateParameters(header.Parameter, "resource_metadata");
if (resourceMetadataUrl != null)
Expand Down
17 changes: 17 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/ITokenCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace ModelContextProtocol.Authentication;

/// <summary>
/// Allows the client to cache access tokens beyond the lifetime of the transport.
/// </summary>
public interface ITokenCache
{
/// <summary>
/// Cache the token. After a new access token is acquired, this method is invoked to store it.
/// </summary>
ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken);

/// <summary>
/// Get the cached token. This method is invoked for every request.
/// </summary>
ValueTask<TokenContainer?> GetTokensAsync(CancellationToken cancellationToken);
}
27 changes: 27 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

namespace ModelContextProtocol.Authentication;

/// <summary>
/// Caches the token in-memory within this instance.
/// </summary>
internal class InMemoryTokenCache : ITokenCache
{
private TokenContainer? _tokens;

/// <summary>
/// Cache the token.
/// </summary>
public ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken)
{
_tokens = tokens;
return default;
}

/// <summary>
/// Get the cached token.
/// </summary>
public ValueTask<TokenContainer?> GetTokensAsync(CancellationToken cancellationToken)
{
return new ValueTask<TokenContainer?>(_tokens);
}
}
42 changes: 12 additions & 30 deletions src/ModelContextProtocol.Core/Authentication/TokenContainer.cs
Original file line number Diff line number Diff line change
@@ -1,57 +1,39 @@
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Authentication;

/// <summary>
/// Represents a token response from the OAuth server.
/// Represents a cacheable combination of tokens ready to be used for authentication.
/// </summary>
internal sealed class TokenContainer
public sealed class TokenContainer
{
/// <summary>
/// Gets or sets the token type (typically "Bearer").
/// </summary>
public required string TokenType { get; set; }

/// <summary>
/// Gets or sets the access token.
/// </summary>
[JsonPropertyName("access_token")]
public string AccessToken { get; set; } = string.Empty;
public required string AccessToken { get; set; }

/// <summary>
/// Gets or sets the refresh token.
/// </summary>
[JsonPropertyName("refresh_token")]
public string? RefreshToken { get; set; }

/// <summary>
/// Gets or sets the number of seconds until the access token expires.
/// </summary>
[JsonPropertyName("expires_in")]
public int ExpiresIn { get; set; }

/// <summary>
/// Gets or sets the extended expiration time in seconds.
/// </summary>
[JsonPropertyName("ext_expires_in")]
public int ExtExpiresIn { get; set; }

/// <summary>
/// Gets or sets the token type (typically "Bearer").
/// </summary>
[JsonPropertyName("token_type")]
public string TokenType { get; set; } = string.Empty;
public int? ExpiresIn { get; set; }

/// <summary>
/// Gets or sets the scope of the access token.
/// </summary>
[JsonPropertyName("scope")]
public string Scope { get; set; } = string.Empty;
public string? Scope { get; set; }

/// <summary>
/// Gets or sets the timestamp when the token was obtained.
/// </summary>
[JsonIgnore]
public DateTimeOffset ObtainedAt { get; set; }
public required DateTimeOffset ObtainedAt { get; set; }

/// <summary>
/// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn.
/// </summary>
[JsonIgnore]
public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn);
internal bool IsExpired => ExpiresIn is not null && DateTimeOffset.UtcNow >= ObtainedAt.AddSeconds(ExpiresIn.Value);
}
39 changes: 39 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/TokenResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Authentication;

/// <summary>
/// Represents a token response from the OAuth server.
/// </summary>
internal sealed class TokenResponse
{
/// <summary>
/// Gets or sets the access token.
/// </summary>
[JsonPropertyName("access_token")]
public required string AccessToken { get; set; }

/// <summary>
/// Gets or sets the refresh token.
/// </summary>
[JsonPropertyName("refresh_token")]
public string? RefreshToken { get; set; }

/// <summary>
/// Gets or sets the number of seconds until the access token expires.
/// </summary>
[JsonPropertyName("expires_in")]
public int? ExpiresIn { get; set; }

/// <summary>
/// Gets or sets the token type (typically "Bearer").
/// </summary>
[JsonPropertyName("token_type")]
public required string TokenType { get; set; }

/// <summary>
/// Gets or sets the scope of the access token.
/// </summary>
[JsonPropertyName("scope")]
public string? Scope { get; set; }
}
2 changes: 1 addition & 1 deletion src/ModelContextProtocol.Core/McpJsonUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element)

[JsonSerializable(typeof(ProtectedResourceMetadata))]
[JsonSerializable(typeof(AuthorizationServerMetadata))]
[JsonSerializable(typeof(TokenContainer))]
[JsonSerializable(typeof(TokenResponse))]
[JsonSerializable(typeof(DynamicClientRegistrationRequest))]
[JsonSerializable(typeof(DynamicClientRegistrationResponse))]

Expand Down
Loading
Loading