Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Net.Http;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;

namespace Microsoft.Identity.Client
{
/// <summary>
/// Factory responsible for creating HttpClient with a custom server certificate validation callback.
/// This is useful for the Service Fabric scenario where the server certificate validation is required using the server cert.
/// See https://learn.microsoft.com/dotnet/api/system.net.http.httpclient?view=net-7.0#instancing for more details.
/// </summary>
/// <remarks>
/// Implementations must be thread safe.
/// Do not create a new HttpClient for each call to <see cref="GetHttpClient"/> - this leads to socket exhaustion.
/// If your app uses Integrated Windows Authentication, ensure <see cref="HttpClientHandler.UseDefaultCredentials"/> is set to true.
/// </remarks>
public interface IMsalSFHttpClientFactory : IMsalHttpClientFactory
{

/// <summary>
/// Method returning an HTTP client that will be used to validate the server certificate through the provided callback.
/// This method is useful when custom certificate validation logic is required,
/// for the managed identity flow running on a service fabric cluster.
/// </summary>
/// <param name="validateServerCert">Callback to validate the server certificate for the Service Fabric.</param>
/// <returns>An HTTP client configured with the provided server certificate validation callback.</returns>
HttpClient GetHttpClient(Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCert);
}
}
42 changes: 29 additions & 13 deletions src/client/Microsoft.Identity.Client/Http/HttpManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -51,8 +52,8 @@ public async Task<HttpResponse> SendRequestAsync(
ILoggerAdapter logger,
bool doNotThrow,
X509Certificate2 bindingCertificate,
HttpClient customHttpClient,
CancellationToken cancellationToken,
Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCert,
CancellationToken cancellationToken,
int retryCount = 0)
{
Exception timeoutException = null;
Expand All @@ -76,8 +77,7 @@ public async Task<HttpResponse> SendRequestAsync(
clonedBody,
method,
bindingCertificate,
customHttpClient,
logger,
validateServerCert, logger,
cancellationToken).ConfigureAwait(false);
}

Expand Down Expand Up @@ -113,9 +113,8 @@ public async Task<HttpResponse> SendRequestAsync(
logger,
doNotThrow,
bindingCertificate,
customHttpClient,
cancellationToken: cancellationToken,
retryCount) // Pass the updated retry count
validateServerCert, cancellationToken: cancellationToken,
retryCount: retryCount) // Pass the updated retry count
.ConfigureAwait(false);
}

Expand Down Expand Up @@ -146,15 +145,32 @@ public async Task<HttpResponse> SendRequestAsync(
return response;
}

private HttpClient GetHttpClient(X509Certificate2 x509Certificate2, HttpClient customHttpClient) {
if (x509Certificate2 != null && customHttpClient != null)
private HttpClient GetHttpClient(X509Certificate2 x509Certificate2, Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCert)
{
if (x509Certificate2 != null && validateServerCert != null)
{
throw new NotImplementedException("Mtls certificate cannot be used with service fabric. A custom http client is used for service fabric managed identity to validate the server certificate.");
}

if (customHttpClient != null)
if (validateServerCert != null)
{
return customHttpClient;
// If the factory is an IMsalSFHttpClientFactory, use it to get an HttpClient with the custom handler
// that validates the server certificate.
if (_httpClientFactory is IMsalSFHttpClientFactory msalSFHttpClientFactory)
{
return msalSFHttpClientFactory.GetHttpClient(validateServerCert);
}

#if NET471_OR_GREATER || NETSTANDARD || NET
// If the factory is not an IMsalSFHttpClientFactory, use it to get a default HttpClient
return new HttpClient(new HttpClientHandler()
{

ServerCertificateCustomValidationCallback = validateServerCert
});
#else
return _httpClientFactory.GetHttpClient();
#endif
}

if (_httpClientFactory is IMsalMtlsHttpClientFactory msalMtlsHttpClientFactory)
Expand Down Expand Up @@ -188,7 +204,7 @@ private async Task<HttpResponse> ExecuteAsync(
HttpContent body,
HttpMethod method,
X509Certificate2 bindingCertificate,
HttpClient customHttpClient,
Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCert,
ILoggerAdapter logger,
CancellationToken cancellationToken = default)
{
Expand All @@ -203,7 +219,7 @@ private async Task<HttpResponse> ExecuteAsync(

Stopwatch sw = Stopwatch.StartNew();

HttpClient client = GetHttpClient(bindingCertificate, customHttpClient);
HttpClient client = GetHttpClient(bindingCertificate, validateServerCert);

using (HttpResponseMessage responseMessage =
await client.SendAsync(requestMessage, cancellationToken).ConfigureAwait(false))
Expand Down
7 changes: 3 additions & 4 deletions src/client/Microsoft.Identity.Client/Http/IHttpManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;

namespace Microsoft.Identity.Client.Http
{
Expand All @@ -26,8 +26,7 @@ internal interface IHttpManager
/// <param name="logger">Logger from the request context.</param>
/// <param name="doNotThrow">Flag to decide if MsalServiceException is thrown or the response is returned in case of 5xx errors.</param>
/// <param name="mtlsCertificate">Certificate used for MTLS authentication.</param>
/// <param name="customHttpClient">Custom http client which bypasses the HttpClientFactory.
/// This is needed for service fabric managed identity where a cert validation callback is added to the handler.</param>
/// <param name="validateServerCertificate">Callback to validate the server cert for service fabric managed identity flow.</param>
/// <param name="cancellationToken"></param>
/// <param name="retryCount">Number of retries to be attempted in case of retriable status codes.</param>
/// <returns></returns>
Expand All @@ -39,7 +38,7 @@ Task<HttpResponse> SendRequestAsync(
ILoggerAdapter logger,
bool doNotThrow,
X509Certificate2 mtlsCertificate,
HttpClient customHttpClient,
Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> validateServerCertificate,
CancellationToken cancellationToken,
int retryCount = 0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,14 @@ private async Task<RegionInfo> DiscoverAsync(ILoggerAdapter logger, Cancellation
Uri imdsUri = BuildImdsUri(DefaultApiVersion);

HttpResponse response = await _httpManager.SendRequestAsync(
imdsUri,
headers,
body: null,
HttpMethod.Get,
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
GetCancellationToken(requestCancellationToken))
imdsUri,
headers,
body: null,
method: HttpMethod.Get,
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
validateServerCertificate: null, cancellationToken: GetCancellationToken(requestCancellationToken))
.ConfigureAwait(false);

// A bad request occurs when the version in the IMDS call is no longer supported.
Expand All @@ -219,12 +218,11 @@ private async Task<RegionInfo> DiscoverAsync(ILoggerAdapter logger, Cancellation
imdsUri,
headers,
body: null,
HttpMethod.Get,
method: HttpMethod.Get,
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
GetCancellationToken(requestCancellationToken))
validateServerCertificate: null, cancellationToken: GetCancellationToken(requestCancellationToken))
.ConfigureAwait(false); // Call again with updated version
}

Expand Down Expand Up @@ -318,16 +316,16 @@ private async Task<string> GetImdsUriApiVersionAsync(ILoggerAdapter logger, Dict
Uri imdsErrorUri = new(ImdsEndpoint);

HttpResponse response = await _httpManager.SendRequestAsync(
imdsErrorUri,
headers,
body: null,
HttpMethod.Get,
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
GetCancellationToken(userCancellationToken))
.ConfigureAwait(false);
imdsErrorUri,
headers,
body: null,
method: HttpMethod.Get,
logger: logger,
doNotThrow: false,
mtlsCertificate: null,
validateServerCertificate: null,
cancellationToken: GetCancellationToken(userCancellationToken))
.ConfigureAwait(false);

// When IMDS endpoint is called without the api version query param, bad request response comes back with latest version.
if (response.StatusCode == HttpStatusCode.BadRequest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
Expand Down Expand Up @@ -33,12 +32,11 @@ public async Task ValidateAuthorityAsync(
new Uri(webFingerUrl),
null,
body: null,
System.Net.Http.HttpMethod.Get,
method: System.Net.Http.HttpMethod.Get,
logger: _requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
_requestContext.UserCancellationToken)
validateServerCertificate: null, cancellationToken: _requestContext.UserCancellationToken)
.ConfigureAwait(false);

if (httpResponse.StatusCode != HttpStatusCode.OK)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,11 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
request.ComputeUri(),
request.Headers,
body: null,
HttpMethod.Get,
method: HttpMethod.Get,
logger: _requestContext.Logger,
doNotThrow: true,
mtlsCertificate: null,
GetHttpClientWithSslValidation(_requestContext),
cancellationToken).ConfigureAwait(false);
validateServerCertificate: null, cancellationToken: cancellationToken).ConfigureAwait(false);
}
else
{
Expand All @@ -75,12 +74,11 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
request.ComputeUri(),
request.Headers,
body: new FormUrlEncodedContent(request.BodyParameters),
HttpMethod.Post,
method: HttpMethod.Post,
logger: _requestContext.Logger,
doNotThrow: true,
mtlsCertificate: null,
GetHttpClientWithSslValidation(_requestContext),
cancellationToken)
validateServerCertificate: null, cancellationToken: cancellationToken)
.ConfigureAwait(false);

}
Expand All @@ -94,10 +92,13 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
}
}

// This method is internal for testing purposes.
internal virtual HttpClient GetHttpClientWithSslValidation(RequestContext requestContext)
// This method is used to validate the server certificate.
// It is overridden in the Service Fabric managed identity source to validate the certificate thumbprint.
// The default implementation always returns true.
internal virtual bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate,
System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors)
{
return null;
return true;
}

protected virtual Task<ManagedIdentityResponse> HandleResponseAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@
using System;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.PlatformsCommon.Shared;
using Microsoft.Identity.Client.Utils;

namespace Microsoft.Identity.Client.ManagedIdentity
{
Expand Down Expand Up @@ -127,16 +124,16 @@ protected override async Task<ManagedIdentityResponse> HandleResponseAsync(
request.Headers.Add("Authorization", authHeaderValue);

response = await _requestContext.ServiceBundle.HttpManager.SendRequestAsync(
request.ComputeUri(),
request.Headers,
body: null,
System.Net.Http.HttpMethod.Get,
logger: _requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
customHttpClient: null,
cancellationToken)
.ConfigureAwait(false);
request.ComputeUri(),
request.Headers,
body: null,
method: System.Net.Http.HttpMethod.Get,
logger: _requestContext.Logger,
doNotThrow: false,
mtlsCertificate: null,
validateServerCertificate: null,
cancellationToken: cancellationToken)
.ConfigureAwait(false);

return await base.HandleResponseAsync(parameters, response, cancellationToken).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Globalization;
using System.Net.Http;
using System.Net.Security;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;

Expand Down Expand Up @@ -42,42 +43,17 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)
return new ServiceFabricManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.IdentityHeader);
}

internal override HttpClient GetHttpClientWithSslValidation(RequestContext requestContext)
internal override bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate,
System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors)
{
if (_httpClientLazy == null)
if (sslPolicyErrors == SslPolicyErrors.None)
{
_httpClientLazy = new Lazy<HttpClient>(() =>
{
HttpClientHandler handler = CreateHandlerWithSslValidation(requestContext.Logger);
return new HttpClient(handler);
});
return true;
}

return _httpClientLazy.Value;
return string.Equals(certificate.GetCertHashString(), EnvironmentVariables.IdentityServerThumbprint, StringComparison.OrdinalIgnoreCase);
}

internal HttpClientHandler CreateHandlerWithSslValidation(ILoggerAdapter logger)
{
#if NET471_OR_GREATER || NETSTANDARD || NET
logger.Info(() => "[Managed Identity] Setting up server certificate validation callback.");
return new HttpClientHandler
{
ServerCertificateCustomValidationCallback = (message, certificate, chain, sslPolicyErrors) =>
{
if (sslPolicyErrors != System.Net.Security.SslPolicyErrors.None)
{
return 0 == string.Compare(certificate.Thumbprint, EnvironmentVariables.IdentityServerThumbprint, StringComparison.OrdinalIgnoreCase);
}
return true;
}
};
#else
logger.Warning("[Managed Identity] Server certificate validation callback is not supported on .NET Framework.");
return new HttpClientHandler();
#endif
}


private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri endpoint, string identityHeaderValue) :
base(requestContext, ManagedIdentitySource.ServiceFabric)
{
Expand Down
Loading
Loading