Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,18 @@ public class MutualTlsOptions

/// <summary>
/// Specifies a separate domain to run the MTLS endpoints on.
/// If the string does not contain any dots, a subdomain is assumed - e.g. main domain: identityserver.local, MTLS domain: mtls.identityserver.local
/// If the string contains dots, a completely separate domain is assumend, e.g. main domain: identity.app.com, MTLS domain: mtls.app.com. In this case you must set a static issuer name on the options.
/// </summary>
/// <remarks>If the string does not contain any dots, it is treated as a
/// subdomain. For example, if the non-mTLS endpoints are hosted at
/// example.com, configuring this option with the value "mtls" means that
/// mtls is required for requests to mtls.example.com.
///
/// If the string contains dots, it is treated as a complete domain.
/// mTLS will be required for requests whose host name matches the
/// configured domain name completely, including the port number.
/// This allows for separate domains for the mTLS and non-mTLS endpoints.
/// For example, identity.example.com and mtls.example.com.
/// </remarks>
public string? DomainName { get; set; }

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Duende.IdentityServer.Extensions;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using static Duende.IdentityServer.IdentityServerConstants;

Expand Down Expand Up @@ -36,65 +37,116 @@ public MutualTlsEndpointMiddleware(RequestDelegate next, IdentityServerOptions o
_logger = logger;
}

/// <inheritdoc />
public async Task Invoke(HttpContext context, IAuthenticationSchemeProvider schemes)
internal enum MtlsEndpointType
{
None,
SeparateDomain,
Subdomain,
PathBased
}

internal MtlsEndpointType DetermineMtlsEndpointType(HttpContext context, out PathString? subPath)
{
if (_options.MutualTls.Enabled)
subPath = null;

if (!_options.MutualTls.Enabled)
{
// domain-based MTLS
if (_options.MutualTls.DomainName.IsPresent())
return MtlsEndpointType.None;
}

if (_options.MutualTls.DomainName.IsPresent())
{
if (_options.MutualTls.DomainName.Contains('.'))
{
// separate domain
if (_options.MutualTls.DomainName.Contains('.'))
var requestedHost = HostString.FromUriComponent(_options.MutualTls.DomainName);
// Separate domain
if (RequestedHostMatches(context.Request.Host, _options.MutualTls.DomainName))
{
if (context.Request.Host.Host.Equals(_options.MutualTls.DomainName,
StringComparison.OrdinalIgnoreCase))
{
var result = await TriggerCertificateAuthentication(context);
if (!result.Succeeded)
{
return;
}
}
_logger.LogDebug("Requiring mTLS because the request's domain matches the configured mTLS domain name.");
return MtlsEndpointType.SeparateDomain;
}
// sub-domain
else
}
else
{
// Subdomain
if (context.Request.Host.Host.StartsWith(_options.MutualTls.DomainName + ".", StringComparison.OrdinalIgnoreCase))
{
if (context.Request.Host.Host.StartsWith(_options.MutualTls.DomainName + ".", StringComparison.OrdinalIgnoreCase))
{
var result = await TriggerCertificateAuthentication(context);
if (!result.Succeeded)
{
return;
}
}
_logger.LogDebug("Requiring mTLS because the request's subdomain matches the configured mTLS domain name.");
return MtlsEndpointType.Subdomain;
}
}
// path based MTLS
else if (context.Request.Path.StartsWithSegments(ProtocolRoutePaths.MtlsPathPrefix.EnsureLeadingSlash(), out var subPath))

_logger.LogDebug("Not requiring mTLS because this request's domain does not match the configured mTLS domain name.");
return MtlsEndpointType.None;
}

// Check path-based MTLS
if (context.Request.Path.StartsWithSegments(
ProtocolRoutePaths.MtlsPathPrefix.EnsureLeadingSlash(), out var path))
{
_logger.LogDebug("Requiring mTLS because the request's path begins with the configured mTLS path prefix.");
subPath = path;
return MtlsEndpointType.PathBased;
}

return MtlsEndpointType.None;
}

/// <inheritdoc />
public async Task Invoke(HttpContext context, IAuthenticationSchemeProvider schemes)
{
var mtlsConfigurationStyle = DetermineMtlsEndpointType(context, out var subPath);

if (mtlsConfigurationStyle != MtlsEndpointType.None)
{
var result = await TriggerCertificateAuthentication(context);
if (!result.Succeeded)
{
var result = await TriggerCertificateAuthentication(context);
return;
}

if (result.Succeeded)
{
var path = ProtocolRoutePaths.ConnectPathPrefix +
subPath.ToString().EnsureLeadingSlash();
path = path.EnsureLeadingSlash();
// Additional processing for path-based MTLS
if (mtlsConfigurationStyle == MtlsEndpointType.PathBased && subPath.HasValue)
{
var path = ProtocolRoutePaths.ConnectPathPrefix + subPath.Value.ToString().EnsureLeadingSlash();
path = path.EnsureLeadingSlash();

_logger.LogDebug("Rewriting MTLS request from: {oldPath} to: {newPath}",
context.Request.Path.ToString(), path);
context.Request.Path = path;
}
else
{
return;
}
_logger.LogDebug("Rewriting MTLS request from: {oldPath} to: {newPath}",
context.Request.Path.ToString(), path);
context.Request.Path = path;
}
}

await _next(context);
}


private bool RequestedHostMatches(HostString requestHost, string configuredDomain)
{
// Parse the configured domain which might contain a port
string configuredHostname = configuredDomain;
int configuredPort = 443;

int colonIndex = configuredDomain.IndexOf(':');
if (colonIndex >= 0)
{
configuredHostname = configuredDomain.Substring(0, colonIndex);
if (int.TryParse(configuredDomain.Substring(colonIndex + 1), out int port))
{
configuredPort = port;
}
}

// Compare hostnames (case-insensitive)
if (!string.Equals(requestHost.Host, configuredHostname, StringComparison.OrdinalIgnoreCase))
{
return false;
}

var requestPort = requestHost.Port ?? 443;
return requestPort == configuredPort;
}

private async Task<AuthenticateResult> TriggerCertificateAuthentication(HttpContext context)
{
var x509AuthResult =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// Copyright (c) Duende Software. All rights reserved.
// See LICENSE in the project root for license information.

using Duende.IdentityServer.Hosting;
using Duende.IdentityServer.Configuration;
using Microsoft.AspNetCore.Http;
using UnitTests.Common;
using System.Threading.Tasks;
using Xunit;

namespace UnitTests.Hosting;

public class MutualTlsEndpointMiddlewareTests
{
private readonly IdentityServerOptions _options;
private readonly MutualTlsEndpointMiddleware _middleware;
private readonly HttpContext _httpContext;

public MutualTlsEndpointMiddlewareTests()
{
var testLogger = TestLogger.Create<MutualTlsEndpointMiddleware>();
_options = TestIdentityServerOptions.Create();
_middleware = new MutualTlsEndpointMiddleware(
next: (ctx) => Task.CompletedTask,
options: _options,
logger: testLogger
);
_httpContext = new DefaultHttpContext();
}

[Fact]
internal void mtls_endpoint_type_when_mtls_disabled_should_be_none()
{
_options.MutualTls.Enabled = false;
var result = _middleware.DetermineMtlsEndpointType(_httpContext, out var subPath);
Assert.Equal(MutualTlsEndpointMiddleware.MtlsEndpointType.None, result);
Assert.Null(subPath);
}

[Theory]
[InlineData("mtls.example.com", "mtls.example.com", MutualTlsEndpointMiddleware.MtlsEndpointType.SeparateDomain)]
[InlineData("mTLS.example.com", "mtls.example.com", MutualTlsEndpointMiddleware.MtlsEndpointType.SeparateDomain)]
[InlineData("mtls.example.com", "mTLS.example.com", MutualTlsEndpointMiddleware.MtlsEndpointType.SeparateDomain)]
[InlineData("mtls.example.com:443", "mtls.example.com", MutualTlsEndpointMiddleware.MtlsEndpointType.SeparateDomain)]
[InlineData("mtls.example.com:5001", "mtls.example.com", MutualTlsEndpointMiddleware.MtlsEndpointType.None)]
[InlineData("mtls.example.com", "mtls.example.com:443", MutualTlsEndpointMiddleware.MtlsEndpointType.SeparateDomain)]
[InlineData("mtls.example.com:443", "mtls.example.com:443", MutualTlsEndpointMiddleware.MtlsEndpointType.SeparateDomain)]
[InlineData("mtls.example.com:5001", "mtls.example.com:443", MutualTlsEndpointMiddleware.MtlsEndpointType.None)]
[InlineData("mtls.example.com", "mtls.example.com:5001", MutualTlsEndpointMiddleware.MtlsEndpointType.None)]
[InlineData("mtls.example.com:443", "mtls.example.com:5001", MutualTlsEndpointMiddleware.MtlsEndpointType.None)]
[InlineData("mtls.example.com:5001", "mtls.example.com:5001", MutualTlsEndpointMiddleware.MtlsEndpointType.SeparateDomain)]
internal void mtls_endpoint_type_separate_domain_should_be_detected(string requestedHost, string configuredDomainName, MutualTlsEndpointMiddleware.MtlsEndpointType expectedType)
{
// Arrange
_options.MutualTls.Enabled = true;
_options.MutualTls.DomainName = configuredDomainName;
_httpContext.Request.Host = new HostString(requestedHost);

// Act
var result = _middleware.DetermineMtlsEndpointType(_httpContext, out var subPath);

// Assert
Assert.Equal(expectedType, result);
Assert.Null(subPath);
}

[Theory]
[InlineData("example.com", "mtls.example.com")]
[InlineData("example.com:443", "mtls.example.com")]
[InlineData("example.com:5001", "mtls.example.com")]
[InlineData("other.example.com", "mtls.example.com")]
[InlineData("other.example.com:443", "mtls.example.com")]
[InlineData("example.com", "mtls.example.com:443")]
[InlineData("example.com:443", "mtls.example.com:443")]
[InlineData("other.example.com", "mtls.example.com:5001")]
[InlineData("other.example.com:5001", "mtls.example.com:5001")]
internal void mtls_endpoint_type_separate_domain_should_not_match_different_domain(string requestedHost, string configuredDomainName)
{
// Arrange
_options.MutualTls.Enabled = true;
_options.MutualTls.DomainName = configuredDomainName;
_httpContext.Request.Host = new HostString(requestedHost);

// Act
var result = _middleware.DetermineMtlsEndpointType(_httpContext, out var subPath);

// Assert
Assert.Equal(MutualTlsEndpointMiddleware.MtlsEndpointType.None, result);
Assert.Null(subPath);
}

[Theory]
[InlineData("mtls.example.com", "mtls")]
[InlineData("mtls.example.com", "mTLS")]
[InlineData("mTLS.example.com", "mtls")]
[InlineData("mtls.example.com:443", "mtls")]
[InlineData("mtls.example.com:5001", "mtls")]
internal void mtls_endpoint_type_subdomain_should_be_detected(string requestedHost, string configuredDomainName)
{
// Arrange
_options.MutualTls.Enabled = true;
_options.MutualTls.DomainName = configuredDomainName;
_httpContext.Request.Host = new HostString(requestedHost);

// Act
var result = _middleware.DetermineMtlsEndpointType(_httpContext, out var subPath);

// Assert
Assert.Equal(MutualTlsEndpointMiddleware.MtlsEndpointType.Subdomain, result);
Assert.Null(subPath);
}

[Theory]
[InlineData("api.example.com", "mtls")]
[InlineData("api.example.com:443", "mtls")]
[InlineData("example.com", "mtls")]
[InlineData("example.com:5001", "mtls")]
internal void mtls_endpoint_type_subdomain_should_not_match_different_subdomain(string requestedHost, string configuredDomainName)
{
// Arrange
_options.MutualTls.Enabled = true;
_options.MutualTls.DomainName = configuredDomainName;
_httpContext.Request.Host = new HostString(requestedHost);

// Act
var result = _middleware.DetermineMtlsEndpointType(_httpContext, out var subPath);

// Assert
Assert.Equal(MutualTlsEndpointMiddleware.MtlsEndpointType.None, result);
Assert.Null(subPath);
}

[Theory]
[InlineData("/connect/mtls/token")]
[InlineData("/connect/mTLS/token")]
internal void mtls_endpoint_type_path_based_should_be_detected(string requestedPath)
{
// Arrange
_options.MutualTls.Enabled = true;
_httpContext.Request.Path = new PathString(requestedPath);

// Act
var result = _middleware.DetermineMtlsEndpointType(_httpContext, out var subPath);

// Assert
Assert.Equal(MutualTlsEndpointMiddleware.MtlsEndpointType.PathBased, result);
Assert.Equal("/token", subPath!.Value);
}

[Fact]
internal void mtls_endpoint_type_should_be_none_when_enabled_but_no_matching_configuration()
{
// Arrange
_options.MutualTls.Enabled = true;
_options.MutualTls.DomainName = "mtls.example.com";
_httpContext.Request.Host = new HostString("regular.example.com");
_httpContext.Request.Path = new PathString("/connect/token");

// Act
var result = _middleware.DetermineMtlsEndpointType(_httpContext, out var subPath);

// Assert
Assert.Equal(MutualTlsEndpointMiddleware.MtlsEndpointType.None, result);
Assert.Null(subPath);
}
}