diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpWebSocketListener.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpWebSocketListener.cs index 5a33f11ab2a..9e18c89e200 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpWebSocketListener.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpWebSocketListener.cs @@ -29,7 +29,7 @@ public AmqpWebSocketListener( public string SubProtocol => Constants.WebSocketSubProtocol; - public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option localEndPoint, EndPoint remoteEndPoint, string correlationId, X509Certificate2 clientCert, IList clientCertChain) + public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option localEndPoint, EndPoint remoteEndPoint, string correlationId, X509Certificate2 clientCert, IList clientCertChain, IAuthenticator proxyAuthenticator = null) { try { @@ -46,7 +46,7 @@ public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option clientCertChain); + IList clientCertChain, + IAuthenticator proxyAuthenticator = null); } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/HttpProxiedCertificateExtractor.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/HttpProxiedCertificateExtractor.cs index a6645fba71f..db60178df49 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/HttpProxiedCertificateExtractor.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/HttpProxiedCertificateExtractor.cs @@ -107,12 +107,12 @@ enum EventIds public static void AuthenticationApiProxy(string remoteAddress) { - Log.LogInformation((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}"); + Log.LogDebug((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}"); } public static void AuthenticateApiProxySuccess() { - Log.LogInformation((int)EventIds.AuthenticationSuccess, $"Authentication attempt through ApiProxy success"); + Log.LogDebug((int)EventIds.AuthenticationSuccess, $"Authentication attempt through ApiProxy success"); } public static void AuthenticateApiProxyFailed(Exception ex) diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/controllers/HttpRequestAuthenticator.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/controllers/HttpRequestAuthenticator.cs index a9d7048df7e..47e70023797 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/controllers/HttpRequestAuthenticator.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/controllers/HttpRequestAuthenticator.cs @@ -134,7 +134,7 @@ public static void AuthenticationSucceeded(IIdentity identity) public static void AuthenticationApiProxy(string remoteAddress) { - Log.LogInformation((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}"); + Log.LogDebug((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}"); } } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/middleware/WebSocketHandlingMiddleware.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/middleware/WebSocketHandlingMiddleware.cs index c6f9b7381c5..e1ec0b8e5c7 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/middleware/WebSocketHandlingMiddleware.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Http/middleware/WebSocketHandlingMiddleware.cs @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Http.Middleware using System.Linq; using System.Net; using System.Net.WebSockets; + using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading.Tasks; @@ -78,17 +79,31 @@ async Task ProcessRequestAsync(HttpContext context, IWebSocketListener listener, var remoteEndPoint = new IPEndPoint(context.Connection.RemoteIpAddress, context.Connection.RemotePort); X509Certificate2 cert = await context.Connection.GetClientCertificateAsync(); + IAuthenticator proxyAuthenticator = null; if (cert == null) { - var certExtractor = await this.httpProxiedCertificateExtractorProvider; - cert = (await certExtractor.GetClientCertificate(context)).OrDefault(); + try + { + var certExtractor = await this.httpProxiedCertificateExtractorProvider; + // if not certificate in header it returns null, no api proxy authentication needed in this case + // if certificate was set in header it means it was forwarded by api proxy and authenticates api proxy by sas token + // and throws AuthenticationException if api proxy was not authenticated or returns the certificate if api proxy authentication succeeded + cert = (await certExtractor.GetClientCertificate(context)).OrDefault(); + } + catch (AuthenticationException ex) + { + Events.AuthenticationApiProxyFailed(remoteEndPoint.ToString(), ex); + // Set authenticator to unauthorize the call from subprotocol level (Mqtt or Amqp) + proxyAuthenticator = new NullAuthenticator(); + cert = context.GetForwardedCertificate(); + } } if (cert != null) { IList certChain = context.GetClientCertificateChain(); - await listener.ProcessWebSocketRequestAsync(webSocket, localEndPoint, remoteEndPoint, correlationId, cert, certChain); + await listener.ProcessWebSocketRequestAsync(webSocket, localEndPoint, remoteEndPoint, correlationId, cert, certChain, proxyAuthenticator); } else { @@ -129,7 +144,10 @@ public static void InvalidCertificate(Exception ex, string connectionIp) => Log.LogWarning((int)EventIds.InvalidCertificate, Invariant($"Invalid client certificate for incoming connection: {connectionIp}, Exception: {ex.Message}")); public static void AuthenticationApiProxy(string remoteAddress) => - Log.LogInformation((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}"); + Log.LogDebug((int)EventIds.AuthenticationApiProxy, $"Received authentication attempt through ApiProxy for {remoteAddress}"); + + public static void AuthenticationApiProxyFailed(string remoteAddress, Exception ex) => + Log.LogError((int)EventIds.AuthenticationApiProxy, $"Failed authentication attempt through ApiProxy for {remoteAddress}", ex); } } diff --git a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Mqtt/MqttWebSocketListener.cs b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Mqtt/MqttWebSocketListener.cs index 296fb2e9577..9ebf760e803 100644 --- a/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Mqtt/MqttWebSocketListener.cs +++ b/edge-hub/core/src/Microsoft.Azure.Devices.Edge.Hub.Mqtt/MqttWebSocketListener.cs @@ -75,9 +75,10 @@ public Task ProcessWebSocketRequestAsync( EndPoint remoteEndPoint, string correlationId, X509Certificate2 clientCert, - IList clientCertChain) + IList clientCertChain, + IAuthenticator proxyAuthenticator = null) { - var identityProvider = new DeviceIdentityProvider(this.authenticator, this.usernameParser, this.clientCredentialsFactory, this.metadataStore, this.clientCertAuthAllowed); + var identityProvider = new DeviceIdentityProvider(proxyAuthenticator ?? this.authenticator, this.usernameParser, this.clientCredentialsFactory, this.metadataStore, this.clientCertAuthAllowed); identityProvider.RegisterConnectionCertificate(clientCert, clientCertChain); return this.ProcessWebSocketRequestAsyncInternal(identityProvider, webSocket, localEndPoint, remoteEndPoint, correlationId); } diff --git a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Http.Test/WebSocketHandlingMiddlewareTest.cs b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Http.Test/WebSocketHandlingMiddlewareTest.cs index 73ca40b3d01..0e1c417035e 100644 --- a/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Http.Test/WebSocketHandlingMiddlewareTest.cs +++ b/edge-hub/core/test/Microsoft.Azure.Devices.Edge.Hub.Http.Test/WebSocketHandlingMiddlewareTest.cs @@ -5,15 +5,18 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Http.Test using System.Collections.Generic; using System.Net; using System.Net.WebSockets; + using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Hub.Http.Extensions; using Microsoft.Azure.Devices.Edge.Hub.Http.Middleware; using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Azure.Devices.Edge.Util.Test.Common; + using Microsoft.Extensions.Primitives; using Moq; using Xunit; @@ -121,6 +124,105 @@ public async Task SetsBadrequestWhenNoRegisteredListener() Assert.Equal((int)HttpStatusCode.BadRequest, httpContext.Response.StatusCode); } + [Fact] + public async Task UnauthorizedRequestWhenProxyAuthFails() + { + var next = Mock.Of(); + + var listener = new Mock(); + listener.Setup(wsl => wsl.SubProtocol).Returns("abc"); + listener.Setup( + wsl => wsl.ProcessWebSocketRequestAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.Is(auth => auth != null && auth.GetType() == typeof(NullAuthenticator)))) + .Returns(Task.CompletedTask); + + var registry = new WebSocketListenerRegistry(); + registry.TryRegister(listener.Object); + var certContentBytes = Util.Test.Common.CertificateHelper.GenerateSelfSignedCert($"test_cert").Export(X509ContentType.Cert); + string certContentBase64 = Convert.ToBase64String(certContentBytes); + string clientCertString = $"-----BEGIN CERTIFICATE-----\n{certContentBase64}\n-----END CERTIFICATE-----\n"; + clientCertString = WebUtility.UrlEncode(clientCertString); + HttpContext httpContext = this.ContextWithRequestedSubprotocolsAndForwardedCert(new StringValues(clientCertString), "abc"); + var certExtractor = new Mock(); + certExtractor.Setup(p => p.GetClientCertificate(It.IsAny())).ThrowsAsync(new AuthenticationException()); + + var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor.Object)); + await middleware.Invoke(httpContext); + + listener.VerifyAll(); + } + + [Fact] + public async Task AuthorizedRequestWhenProxyAuthSuccess() + { + var next = Mock.Of(); + + var listener = new Mock(); + listener.Setup(wsl => wsl.SubProtocol).Returns("abc"); + listener.Setup( + wsl => wsl.ProcessWebSocketRequestAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny>(), + It.Is(auth => auth == null))) + .Returns(Task.CompletedTask); + + var registry = new WebSocketListenerRegistry(); + registry.TryRegister(listener.Object); + var certContentBytes = Util.Test.Common.CertificateHelper.GenerateSelfSignedCert($"test_cert").Export(X509ContentType.Cert); + string certContentBase64 = Convert.ToBase64String(certContentBytes); + string clientCertString = $"-----BEGIN CERTIFICATE-----\n{certContentBase64}\n-----END CERTIFICATE-----\n"; + clientCertString = WebUtility.UrlEncode(clientCertString); + HttpContext httpContext = this.ContextWithRequestedSubprotocolsAndForwardedCert(new StringValues(clientCertString), "abc"); + var certExtractor = new Mock(); + certExtractor.Setup(p => p.GetClientCertificate(It.IsAny())).ReturnsAsync(Option.Some(new X509Certificate2(certContentBytes))); + + var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor.Object)); + await middleware.Invoke(httpContext); + + listener.VerifyAll(); + } + + [Fact] + public async Task AuthorizedRequestWhenCertIsNotSet() + { + var next = Mock.Of(); + + var listener = new Mock(); + listener.Setup(wsl => wsl.SubProtocol).Returns("abc"); + listener.Setup( + wsl => wsl.ProcessWebSocketRequestAsync( + It.IsAny(), + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(Task.CompletedTask); + + var registry = new WebSocketListenerRegistry(); + registry.TryRegister(listener.Object); + + HttpContext httpContext = this.ContextWithRequestedSubprotocols("abc"); + var authenticator = new Mock(); + authenticator.Setup(p => p.AuthenticateAsync(It.IsAny())).ReturnsAsync(false); + + IHttpProxiedCertificateExtractor certExtractor = new HttpProxiedCertificateExtractor(authenticator.Object, Mock.Of(), "hub", "edge", "proxy"); + + var middleware = new WebSocketHandlingMiddleware(next, registry, Task.FromResult(certExtractor)); + await middleware.Invoke(httpContext); + + authenticator.Verify(auth => auth.AuthenticateAsync(It.IsAny()), Times.Never); + listener.VerifyAll(); + } + static IWebSocketListenerRegistry ObservingWebSocketListenerRegistry(List correlationIds) { var registry = new Mock(); @@ -193,6 +295,28 @@ HttpContext ContextWithRequestedSubprotocols(params string[] subprotocols) && conn.ClientCertificate == new X509Certificate2())); } + HttpContext ContextWithRequestedSubprotocolsAndForwardedCert(StringValues cert, params string[] subprotocols) + { + return Mock.Of( + ctx => + ctx.WebSockets == Mock.Of( + wsm => + wsm.WebSocketRequestedProtocols == subprotocols + && wsm.IsWebSocketRequest + && wsm.AcceptWebSocketAsync(It.IsAny()) == Task.FromResult(Mock.Of())) + && ctx.Request == Mock.Of( + req => + req.Headers == Mock.Of(h => h.TryGetValue("x-ms-edge-clientcert", out cert)) == true ) + && ctx.Response == Mock.Of() + && ctx.Features == Mock.Of( + fc => fc.Get() == Mock.Of(f => f.ChainElements == new List())) + && ctx.Connection == Mock.Of( + conn => conn.LocalIpAddress == new IPAddress(123) + && conn.LocalPort == It.IsAny() + && conn.RemoteIpAddress == new IPAddress(123) && conn.RemotePort == It.IsAny() + && conn.ClientCertificate == new X509Certificate2())); + } + RequestDelegate ThrowingNextDelegate() { return ctx => throw new Exception("delegate 'next' should not be called");