Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions MCPify/Core/Auth/DeviceCode/DeviceCodeAuthentication.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public class DeviceCodeAuthentication : IAuthenticationProvider
private readonly IMcpContextAccessor _mcpContextAccessor;
private readonly HttpClient _httpClient;
private readonly Func<string, string, Task> _userPrompt;
private readonly string? _resourceUrl; // RFC 8707 resource parameter
private const string _deviceCodeProviderName = "DeviceCode";

public DeviceCodeAuthentication(
Expand All @@ -25,7 +26,8 @@ public DeviceCodeAuthentication(
ISecureTokenStore secureTokenStore,
IMcpContextAccessor mcpContextAccessor,
Func<string, string, Task> userPrompt,
HttpClient? httpClient = null)
HttpClient? httpClient = null,
string? resourceUrl = null)
{
_clientId = clientId;
_deviceCodeEndpoint = deviceCodeEndpoint;
Expand All @@ -35,6 +37,7 @@ public DeviceCodeAuthentication(
_mcpContextAccessor = mcpContextAccessor;
_userPrompt = userPrompt;
_httpClient = httpClient ?? new HttpClient();
_resourceUrl = resourceUrl;
}

public async Task ApplyAsync(HttpRequestMessage request, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -73,11 +76,11 @@ public async Task ApplyAsync(HttpRequestMessage request, CancellationToken cance

private async Task<TokenData> PerformDeviceLoginAsync(CancellationToken cancellationToken)
{
var codeRequest = new FormUrlEncodedContent(new Dictionary<string, string>
{
{ "client_id", _clientId },
{ "scope", _scope }
});
var codeRequest = FormUrlEncoded.Create()
.Add("client_id", _clientId)
.Add("scope", _scope)
.AddIfNotEmpty("resource", _resourceUrl) // RFC 8707
.ToContent();

var codeResponse = await _httpClient.PostAsync(_deviceCodeEndpoint, codeRequest, cancellationToken);
codeResponse.EnsureSuccessStatusCode();
Expand All @@ -94,12 +97,12 @@ private async Task<TokenData> PerformDeviceLoginAsync(CancellationToken cancella
{
await Task.Delay(interval * 1000, cancellationToken);

var tokenRequest = new FormUrlEncodedContent(new Dictionary<string, string>
{
{ "grant_type", "urn:ietf:params:oauth:grant-type:device_code" },
{ "client_id", _clientId },
{ "device_code", codeData.device_code }
});
var tokenRequest = FormUrlEncoded.Create()
.Add("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
.Add("client_id", _clientId)
.Add("device_code", codeData.device_code)
.AddIfNotEmpty("resource", _resourceUrl) // RFC 8707
.ToContent();

var tokenResponse = await _httpClient.PostAsync(_tokenEndpoint, tokenRequest, cancellationToken);

Expand All @@ -125,12 +128,12 @@ private async Task<TokenData> PerformDeviceLoginAsync(CancellationToken cancella

private async Task<TokenData> RefreshTokenAsync(string refreshToken, CancellationToken cancellationToken)
{
var content = new FormUrlEncodedContent(new Dictionary<string, string>
{
{ "grant_type", "refresh_token" },
{ "client_id", _clientId },
{ "refresh_token", refreshToken }
});
var content = FormUrlEncoded.Create()
.Add("grant_type", "refresh_token")
.Add("client_id", _clientId)
.Add("refresh_token", refreshToken)
.AddIfNotEmpty("resource", _resourceUrl) // RFC 8707
.ToContent();

var response = await _httpClient.PostAsync(_tokenEndpoint, content, cancellationToken);
response.EnsureSuccessStatusCode();
Expand Down
28 changes: 28 additions & 0 deletions MCPify/Core/Auth/FormUrlEncoded.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
namespace MCPify.Core.Auth;

/// <summary>
/// Fluent helper for creating application/x-www-form-urlencoded POST content.
/// </summary>
internal class FormUrlEncoded
{
private readonly List<KeyValuePair<string, string>> _params = new();

public static FormUrlEncoded Create() => new();

public FormUrlEncoded Add(string key, string value)
{
_params.Add(new(key, value));
return this;
}

public FormUrlEncoded AddIfNotEmpty(string key, string? value)
{
if (!string.IsNullOrEmpty(value))
{
_params.Add(new(key, value));
}
return this;
}

public FormUrlEncodedContent ToContent() => new(_params);
}
16 changes: 16 additions & 0 deletions MCPify/Core/Auth/IAccessTokenValidator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
namespace MCPify.Core.Auth;

/// <summary>
/// Interface for validating access tokens.
/// </summary>
public interface IAccessTokenValidator
{
/// <summary>
/// Validates an access token and returns the validation result.
/// </summary>
/// <param name="token">The access token to validate.</param>
/// <param name="expectedAudience">Optional expected audience value. If null, audience validation is skipped.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>A <see cref="TokenValidationResult"/> containing the validation outcome and extracted claims.</returns>
Task<TokenValidationResult> ValidateAsync(string token, string? expectedAudience, CancellationToken cancellationToken = default);
}
203 changes: 203 additions & 0 deletions MCPify/Core/Auth/JwtAccessTokenValidator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
using System.Text;
using System.Text.Json;

namespace MCPify.Core.Auth;

/// <summary>
/// JWT access token validator that parses and validates JWT tokens without signature verification.
/// This is suitable for tokens that have already been cryptographically validated by the authorization server.
/// Performs expiration, audience, and scope claim extraction.
/// </summary>
public class JwtAccessTokenValidator : IAccessTokenValidator
{
private readonly TokenValidationOptions _options;
private static readonly string[] ScopeClaimNames = { "scope", "scp", "scopes" };

public JwtAccessTokenValidator(TokenValidationOptions options)
{
_options = options;
}

public Task<TokenValidationResult> ValidateAsync(string token, string? expectedAudience, CancellationToken cancellationToken = default)
{
try
{
var parts = token.Split('.');
if (parts.Length < 2)
{
return Task.FromResult(TokenValidationResult.Failure("invalid_token", "Token is not a valid JWT format"));
}

var payloadJson = Base64UrlDecode(parts[1]);
using var doc = JsonDocument.Parse(payloadJson);
var root = doc.RootElement;

// Extract claims
var subject = GetStringClaim(root, "sub");
var issuer = GetStringClaim(root, "iss");
var audiences = GetAudienceClaim(root);
var scopes = GetScopeClaim(root);
var expiresAt = GetExpirationClaim(root);

// Validate expiration
if (expiresAt.HasValue)
{
var now = DateTimeOffset.UtcNow;
if (expiresAt.Value.Add(_options.ClockSkew) < now)
{
return Task.FromResult(TokenValidationResult.Failure("invalid_token", "Token has expired"));
}
}

// Validate audience if requested
if (_options.ValidateAudience && !string.IsNullOrEmpty(expectedAudience))
{
if (audiences.Count == 0 || !audiences.Any(a => string.Equals(a, expectedAudience, StringComparison.OrdinalIgnoreCase)))
{
return Task.FromResult(TokenValidationResult.Failure("invalid_token", $"Token audience does not match expected value: {expectedAudience}"));
}
}

return Task.FromResult(TokenValidationResult.Success(
scopes: scopes,
subject: subject,
audiences: audiences,
issuer: issuer,
expiresAt: expiresAt
));
}
catch (JsonException)
{
return Task.FromResult(TokenValidationResult.Failure("invalid_token", "Token payload is not valid JSON"));
}
catch (FormatException)
{
return Task.FromResult(TokenValidationResult.Failure("invalid_token", "Token payload is not valid Base64URL"));
}
catch (Exception ex)
{
return Task.FromResult(TokenValidationResult.Failure("invalid_token", $"Token validation failed: {ex.Message}"));
}
}

private static string? GetStringClaim(JsonElement root, string claimName)
{
if (root.TryGetProperty(claimName, out var claim) && claim.ValueKind == JsonValueKind.String)
{
return claim.GetString();
}
return null;
}

private List<string> GetAudienceClaim(JsonElement root)
{
if (!root.TryGetProperty("aud", out var audClaim))
{
return new List<string>();
}

if (audClaim.ValueKind == JsonValueKind.String)
{
var value = audClaim.GetString();
return value != null ? new List<string> { value } : new List<string>();
}

if (audClaim.ValueKind == JsonValueKind.Array)
{
var audiences = new List<string>();
foreach (var item in audClaim.EnumerateArray())
{
if (item.ValueKind == JsonValueKind.String)
{
var value = item.GetString();
if (value != null)
{
audiences.Add(value);
}
}
}
return audiences;
}

return new List<string>();
}

private List<string> GetScopeClaim(JsonElement root)
{
// Try the configured claim name first, then fall back to common alternatives
var claimNamesToTry = new List<string> { _options.ScopeClaimName };
foreach (var name in ScopeClaimNames)
{
if (!claimNamesToTry.Contains(name, StringComparer.OrdinalIgnoreCase))
{
claimNamesToTry.Add(name);
}
}

foreach (var claimName in claimNamesToTry)
{
if (!root.TryGetProperty(claimName, out var scopeClaim))
{
continue;
}

if (scopeClaim.ValueKind == JsonValueKind.String)
{
var value = scopeClaim.GetString();
if (!string.IsNullOrEmpty(value))
{
// Scopes are space-separated per RFC 6749
return value.Split(' ', StringSplitOptions.RemoveEmptyEntries).ToList();
}
}

if (scopeClaim.ValueKind == JsonValueKind.Array)
{
var scopes = new List<string>();
foreach (var item in scopeClaim.EnumerateArray())
{
if (item.ValueKind == JsonValueKind.String)
{
var value = item.GetString();
if (!string.IsNullOrEmpty(value))
{
scopes.Add(value);
}
}
}
return scopes;
}
}

return new List<string>();
}

private static DateTimeOffset? GetExpirationClaim(JsonElement root)
{
if (!root.TryGetProperty("exp", out var expClaim))
{
return null;
}

if (expClaim.ValueKind == JsonValueKind.Number)
{
var unixTime = expClaim.GetInt64();
return DateTimeOffset.FromUnixTimeSeconds(unixTime);
}

return null;
}

private static byte[] Base64UrlDecode(string input)
{
var output = input.Replace('-', '+').Replace('_', '/');
switch (output.Length % 4)
{
case 0: break;
case 2: output += "=="; break;
case 3: output += "="; break;
default: throw new FormatException("Illegal base64url string!");
}
return Convert.FromBase64String(output);
}
}
21 changes: 11 additions & 10 deletions MCPify/Core/Auth/OAuth/ClientCredentialsAuthentication.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public class ClientCredentialsAuthentication : IAuthenticationProvider
private readonly ISecureTokenStore _secureTokenStore;
private readonly IMcpContextAccessor _mcpContextAccessor;
private readonly HttpClient _httpClient;
private readonly string? _resourceUrl; // RFC 8707 resource parameter
private const string _clientCredentialsProviderName = "ClientCredentials";

public ClientCredentialsAuthentication(
Expand All @@ -23,7 +24,8 @@ public ClientCredentialsAuthentication(
string scope,
ISecureTokenStore secureTokenStore,
IMcpContextAccessor mcpContextAccessor,
HttpClient? httpClient = null)
HttpClient? httpClient = null,
string? resourceUrl = null)
{
_clientId = clientId;
_clientSecret = clientSecret;
Expand All @@ -32,6 +34,7 @@ public ClientCredentialsAuthentication(
_secureTokenStore = secureTokenStore;
_mcpContextAccessor = mcpContextAccessor;
_httpClient = httpClient ?? new HttpClient();
_resourceUrl = resourceUrl;
}

public async Task ApplyAsync(HttpRequestMessage request, CancellationToken cancellationToken = default)
Expand All @@ -54,15 +57,13 @@ public async Task ApplyAsync(HttpRequestMessage request, CancellationToken cance

private async Task<TokenData> RequestTokenAsync(CancellationToken cancellationToken)
{
var form = new Dictionary<string, string>
{
{ "grant_type", "client_credentials" },
{ "client_id", _clientId },
{ "client_secret", _clientSecret },
{ "scope", _scope }
};

var content = new FormUrlEncodedContent(form);
var content = FormUrlEncoded.Create()
.Add("grant_type", "client_credentials")
.Add("client_id", _clientId)
.Add("client_secret", _clientSecret)
.Add("scope", _scope)
.AddIfNotEmpty("resource", _resourceUrl) // RFC 8707
.ToContent();
var response = await _httpClient.PostAsync(_tokenEndpoint, content, cancellationToken);
response.EnsureSuccessStatusCode();

Expand Down
Loading