Skip to content

Commit

Permalink
Fix ws auth with cert over ApiProxy (#4756) (#4767)
Browse files Browse the repository at this point in the history
For websockets connection the authentication needs to fail at subprotocol level to send the unauthorized to the client. 
It does the check in websockets middleware and if that fails it sets the Authenticator to NullAuthenticator. I think it is a little  weird, but otherwise would need to change the authenticator to receive the proxy token and do the proxy authentication there, better not to change it  now.
  • Loading branch information
ancaantochi authored Apr 2, 2021
1 parent fa60e52 commit 6c48961
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public AmqpWebSocketListener(

public string SubProtocol => Constants.WebSocketSubProtocol;

public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId, X509Certificate2 clientCert, IList<X509Certificate2> clientCertChain)
public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId, X509Certificate2 clientCert, IList<X509Certificate2> clientCertChain, IAuthenticator proxyAuthenticator = null)
{
try
{
Expand All @@ -46,7 +46,7 @@ public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPo
correlationId,
clientCert,
clientCertChain,
this.authenticator,
proxyAuthenticator ?? this.authenticator,
this.clientCredentialsFactory);
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Task ProcessWebSocketRequestAsync(
EndPoint remoteEndPoint,
string correlationId,
X509Certificate2 clientCert,
IList<X509Certificate2> clientCertChain);
IList<X509Certificate2> clientCertChain,
IAuthenticator proxyAuthenticator = null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ enum EventIds

public static void AuthenticationApiProxy(string remoteAddress)
{
Log.LogInformation((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}");
Log.LogDebug((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}");
}

public static void AuthenticateApiProxySuccess()
{
Log.LogInformation((int)EventIds.AuthenticationSuccess, $"Authentication attempt through ApiProxy success");
Log.LogDebug((int)EventIds.AuthenticationSuccess, $"Authentication attempt through ApiProxy success");
}

public static void AuthenticateApiProxyFailed(Exception ex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public static void AuthenticationSucceeded(IIdentity identity)

public static void AuthenticationApiProxy(string remoteAddress)
{
Log.LogInformation((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}");
Log.LogDebug((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Http.Middleware
using System.Linq;
using System.Net;
using System.Net.WebSockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading.Tasks;
Expand Down Expand Up @@ -78,17 +79,31 @@ async Task ProcessRequestAsync(HttpContext context, IWebSocketListener listener,
var remoteEndPoint = new IPEndPoint(context.Connection.RemoteIpAddress, context.Connection.RemotePort);

X509Certificate2 cert = await context.Connection.GetClientCertificateAsync();
IAuthenticator proxyAuthenticator = null;

if (cert == null)
{
var certExtractor = await this.httpProxiedCertificateExtractorProvider;
cert = (await certExtractor.GetClientCertificate(context)).OrDefault();
try
{
var certExtractor = await this.httpProxiedCertificateExtractorProvider;
// if not certificate in header it returns null, no api proxy authentication needed in this case
// if certificate was set in header it means it was forwarded by api proxy and authenticates api proxy by sas token
// and throws AuthenticationException if api proxy was not authenticated or returns the certificate if api proxy authentication succeeded
cert = (await certExtractor.GetClientCertificate(context)).OrDefault();
}
catch (AuthenticationException ex)
{
Events.AuthenticationApiProxyFailed(remoteEndPoint.ToString(), ex);
// Set authenticator to unauthorize the call from subprotocol level (Mqtt or Amqp)
proxyAuthenticator = new NullAuthenticator();
cert = context.GetForwardedCertificate();
}
}

if (cert != null)
{
IList<X509Certificate2> certChain = context.GetClientCertificateChain();
await listener.ProcessWebSocketRequestAsync(webSocket, localEndPoint, remoteEndPoint, correlationId, cert, certChain);
await listener.ProcessWebSocketRequestAsync(webSocket, localEndPoint, remoteEndPoint, correlationId, cert, certChain, proxyAuthenticator);
}
else
{
Expand Down Expand Up @@ -129,7 +144,10 @@ public static void InvalidCertificate(Exception ex, string connectionIp) =>
Log.LogWarning((int)EventIds.InvalidCertificate, Invariant($"Invalid client certificate for incoming connection: {connectionIp}, Exception: {ex.Message}"));

public static void AuthenticationApiProxy(string remoteAddress) =>
Log.LogInformation((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}");
Log.LogDebug((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}");

public static void AuthenticationApiProxyFailed(string remoteAddress, Exception ex) =>
Log.LogError((int)EventIds.AuthenticationApiProxy, $"Failed authentication attempt through ApiProxy for {remoteAddress}", ex);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ public Task ProcessWebSocketRequestAsync(
EndPoint remoteEndPoint,
string correlationId,
X509Certificate2 clientCert,
IList<X509Certificate2> clientCertChain)
IList<X509Certificate2> clientCertChain,
IAuthenticator proxyAuthenticator = null)
{
var identityProvider = new DeviceIdentityProvider(this.authenticator, this.usernameParser, this.clientCredentialsFactory, this.metadataStore, this.clientCertAuthAllowed);
var identityProvider = new DeviceIdentityProvider(proxyAuthenticator ?? this.authenticator, this.usernameParser, this.clientCredentialsFactory, this.metadataStore, this.clientCertAuthAllowed);
identityProvider.RegisterConnectionCertificate(clientCert, clientCertChain);
return this.ProcessWebSocketRequestAsyncInternal(identityProvider, webSocket, localEndPoint, remoteEndPoint, correlationId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Http.Test
using System.Collections.Generic;
using System.Net;
using System.Net.WebSockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Hub.Http.Extensions;
using Microsoft.Azure.Devices.Edge.Hub.Http.Middleware;
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Azure.Devices.Edge.Util.Test.Common;
using Microsoft.Extensions.Primitives;
using Moq;
using Xunit;

Expand Down Expand Up @@ -121,6 +124,105 @@ public async Task SetsBadrequestWhenNoRegisteredListener()
Assert.Equal((int)HttpStatusCode.BadRequest, httpContext.Response.StatusCode);
}

[Fact]
public async Task UnauthorizedRequestWhenProxyAuthFails()
{
var next = Mock.Of<RequestDelegate>();

var listener = new Mock<IWebSocketListener>();
listener.Setup(wsl => wsl.SubProtocol).Returns("abc");
listener.Setup(
wsl => wsl.ProcessWebSocketRequestAsync(
It.IsAny<WebSocket>(),
It.IsAny<Option<EndPoint>>(),
It.IsAny<EndPoint>(),
It.IsAny<string>(),
It.IsAny<X509Certificate2>(),
It.IsAny<IList<X509Certificate2>>(),
It.Is<IAuthenticator>(auth => auth != null && auth.GetType() == typeof(NullAuthenticator))))
.Returns(Task.CompletedTask);

var registry = new WebSocketListenerRegistry();
registry.TryRegister(listener.Object);
var certContentBytes = Util.Test.Common.CertificateHelper.GenerateSelfSignedCert($"test_cert").Export(X509ContentType.Cert);
string certContentBase64 = Convert.ToBase64String(certContentBytes);
string clientCertString = $"-----BEGIN CERTIFICATE-----\n{certContentBase64}\n-----END CERTIFICATE-----\n";
clientCertString = WebUtility.UrlEncode(clientCertString);
HttpContext httpContext = this.ContextWithRequestedSubprotocolsAndForwardedCert(new StringValues(clientCertString), "abc");
var certExtractor = new Mock<IHttpProxiedCertificateExtractor>();
certExtractor.Setup(p => p.GetClientCertificate(It.IsAny<HttpContext>())).ThrowsAsync(new AuthenticationException());

var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor.Object));
await middleware.Invoke(httpContext);

listener.VerifyAll();
}

[Fact]
public async Task AuthorizedRequestWhenProxyAuthSuccess()
{
var next = Mock.Of<RequestDelegate>();

var listener = new Mock<IWebSocketListener>();
listener.Setup(wsl => wsl.SubProtocol).Returns("abc");
listener.Setup(
wsl => wsl.ProcessWebSocketRequestAsync(
It.IsAny<WebSocket>(),
It.IsAny<Option<EndPoint>>(),
It.IsAny<EndPoint>(),
It.IsAny<string>(),
It.IsAny<X509Certificate2>(),
It.IsAny<IList<X509Certificate2>>(),
It.Is<IAuthenticator>(auth => auth == null)))
.Returns(Task.CompletedTask);

var registry = new WebSocketListenerRegistry();
registry.TryRegister(listener.Object);
var certContentBytes = Util.Test.Common.CertificateHelper.GenerateSelfSignedCert($"test_cert").Export(X509ContentType.Cert);
string certContentBase64 = Convert.ToBase64String(certContentBytes);
string clientCertString = $"-----BEGIN CERTIFICATE-----\n{certContentBase64}\n-----END CERTIFICATE-----\n";
clientCertString = WebUtility.UrlEncode(clientCertString);
HttpContext httpContext = this.ContextWithRequestedSubprotocolsAndForwardedCert(new StringValues(clientCertString), "abc");
var certExtractor = new Mock<IHttpProxiedCertificateExtractor>();
certExtractor.Setup(p => p.GetClientCertificate(It.IsAny<HttpContext>())).ReturnsAsync(Option.Some(new X509Certificate2(certContentBytes)));

var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor.Object));
await middleware.Invoke(httpContext);

listener.VerifyAll();
}

[Fact]
public async Task AuthorizedRequestWhenCertIsNotSet()
{
var next = Mock.Of<RequestDelegate>();

var listener = new Mock<IWebSocketListener>();
listener.Setup(wsl => wsl.SubProtocol).Returns("abc");
listener.Setup(
wsl => wsl.ProcessWebSocketRequestAsync(
It.IsAny<WebSocket>(),
It.IsAny<Option<EndPoint>>(),
It.IsAny<EndPoint>(),
It.IsAny<string>()))
.Returns(Task.CompletedTask);

var registry = new WebSocketListenerRegistry();
registry.TryRegister(listener.Object);

HttpContext httpContext = this.ContextWithRequestedSubprotocols("abc");
var authenticator = new Mock<IAuthenticator>();
authenticator.Setup(p => p.AuthenticateAsync(It.IsAny<IClientCredentials>())).ReturnsAsync(false);

IHttpProxiedCertificateExtractor certExtractor = new HttpProxiedCertificateExtractor(authenticator.Object, Mock.Of<IClientCredentialsFactory>(), "hub", "edge", "proxy");

var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor));
await middleware.Invoke(httpContext);

authenticator.Verify(auth => auth.AuthenticateAsync(It.IsAny<IClientCredentials>()), Times.Never);
listener.VerifyAll();
}

static IWebSocketListenerRegistry ObservingWebSocketListenerRegistry(List<string> correlationIds)
{
var registry = new Mock<IWebSocketListenerRegistry>();
Expand Down Expand Up @@ -193,6 +295,28 @@ HttpContext ContextWithRequestedSubprotocols(params string[] subprotocols)
&& conn.ClientCertificate == new X509Certificate2()));
}

HttpContext ContextWithRequestedSubprotocolsAndForwardedCert(StringValues cert, params string[] subprotocols)
{
return Mock.Of<HttpContext>(
ctx =>
ctx.WebSockets == Mock.Of<WebSocketManager>(
wsm =>
wsm.WebSocketRequestedProtocols == subprotocols
&& wsm.IsWebSocketRequest
&& wsm.AcceptWebSocketAsync(It.IsAny<string>()) == Task.FromResult(Mock.Of<WebSocket>()))
&& ctx.Request == Mock.Of<HttpRequest>(
req =>
req.Headers == Mock.Of<IHeaderDictionary>(h => h.TryGetValue("x-ms-edge-clientcert", out cert)) == true )
&& ctx.Response == Mock.Of<HttpResponse>()
&& ctx.Features == Mock.Of<IFeatureCollection>(
fc => fc.Get<ITlsConnectionFeatureExtended>() == Mock.Of<ITlsConnectionFeatureExtended>(f => f.ChainElements == new List<X509Certificate2>()))
&& ctx.Connection == Mock.Of<ConnectionInfo>(
conn => conn.LocalIpAddress == new IPAddress(123)
&& conn.LocalPort == It.IsAny<int>()
&& conn.RemoteIpAddress == new IPAddress(123) && conn.RemotePort == It.IsAny<int>()
&& conn.ClientCertificate == new X509Certificate2()));
}

RequestDelegate ThrowingNextDelegate()
{
return ctx => throw new Exception("delegate 'next' should not be called");
Expand Down

0 comments on commit 6c48961

Please sign in to comment.