Skip to content

Commit

Permalink
AMQP and AMQP+WS support for X.509 authentication (#624)
Browse files Browse the repository at this point in the history
* Add tests for workload client trust bundle and minor fixes

* Add support for HTTP and WS X.509 auth

* Address PR comments

* Fix to obtain the client certificate chain

* Address PR comments

* Address PR comments and bug fix

* AMQP and AMQP+WS support for X.509 authentication

* Add tests and cleanups

* Address PR comments

* Rename files per PR comments

* More renames

* Rename EdgeHubAmqpException to EdgeAmqpException

* Address PR comment and add tests accordingly
  • Loading branch information
mrohera authored and varunpuranik committed Dec 12, 2018
1 parent d8aa924 commit 875776c
Show file tree
Hide file tree
Showing 25 changed files with 759 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ public static class AmqpEventIds
public const int LinkHandler = EventIdStart + 700;
public const int AmqpWebSocketListener = EventIdStart + 800;
public const int ServerWebSocketTransport = EventIdStart + 900;
public const int X509PrinciparAuthenticator = EventIdStart + 1000;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,35 @@ public static AmqpException GetAmqpException(Exception ex)
}
}

// Convert exception to EdgeHubAmqpException
// TODO: Make sure EdgeHubAmqpException is thrown from the right places.
EdgeHubAmqpException edgeHubAmqpException = GetEdgeHubAmqpException(ex);
// Convert exception to EdgeAmqpException
// TODO: Make sure EdgeAmqpException is thrown from the right places.
EdgeAmqpException edgeHubAmqpException = GetEdgeHubAmqpException(ex);
Error amqpError = GenerateAmqpError(edgeHubAmqpException);
return new AmqpException(amqpError);
}

static EdgeHubAmqpException GetEdgeHubAmqpException(Exception exception)
static EdgeAmqpException GetEdgeHubAmqpException(Exception exception)
{
if (exception is EdgeHubAmqpException edgeHubAmqpException)
if (exception is EdgeAmqpException edgeHubAmqpException)
{
return edgeHubAmqpException;
}
else if (exception.UnwindAs<UnauthorizedAccessException>() != null)
{
return new EdgeHubAmqpException("Unauthorized access", ErrorCode.IotHubUnauthorizedAccess, exception);
return new EdgeAmqpException("Unauthorized access", ErrorCode.IotHubUnauthorizedAccess, exception);
}
else if (exception is EdgeHubMessageTooLargeException)
{
return new EdgeHubAmqpException(exception.Message, ErrorCode.MessageTooLarge);
return new EdgeAmqpException(exception.Message, ErrorCode.MessageTooLarge);
}
else if (exception is InvalidOperationException)
{
return new EdgeHubAmqpException("Invalid action performed", ErrorCode.InvalidOperation);
return new EdgeAmqpException("Invalid action performed", ErrorCode.InvalidOperation);
}
return new EdgeHubAmqpException("Encountered server error", ErrorCode.ServerError, exception);
return new EdgeAmqpException("Encountered server error", ErrorCode.ServerError, exception);
}

static Error GenerateAmqpError(EdgeHubAmqpException exception) => new Error
static Error GenerateAmqpError(EdgeAmqpException exception) => new Error
{
Description = JsonConvert.SerializeObject(exception.Message),
Condition = AmqpErrorMapper.GetErrorCondition(exception.ErrorCode),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
using Microsoft.Azure.Amqp.Transport;
using Microsoft.Azure.Devices.Edge.Hub.Amqp.Settings;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Azure.Devices.Edge.Util.Concurrency;
using Microsoft.Extensions.Logging;
Expand All @@ -27,26 +28,32 @@ public class AmqpProtocolHead : IProtocolHead
readonly IWebSocketListenerRegistry webSocketListenerRegistry;
readonly ConcurrentDictionary<uint, AmqpConnection> incomingConnectionMap;
readonly AsyncLock syncLock;
readonly IAuthenticator authenticator;
readonly IClientCredentialsFactory clientCredentialsFactory;

TransportListener amqpTransportListener;

public AmqpProtocolHead(
ITransportSettings transportSettings,
AmqpSettings amqpSettings,
ITransportListenerProvider transportListenerProvider,
IWebSocketListenerRegistry webSocketListenerRegistry)
IWebSocketListenerRegistry webSocketListenerRegistry,
IAuthenticator authenticator,
IClientCredentialsFactory clientCredentialsFactory)
{
this.syncLock = new AsyncLock();
this.transportSettings = Preconditions.CheckNotNull(transportSettings, nameof(transportSettings));
this.amqpSettings = Preconditions.CheckNotNull(amqpSettings, nameof(amqpSettings));
this.transportListenerProvider = Preconditions.CheckNotNull(transportListenerProvider);
this.webSocketListenerRegistry = Preconditions.CheckNotNull(webSocketListenerRegistry);
this.authenticator = Preconditions.CheckNotNull(authenticator, nameof(authenticator));
this.clientCredentialsFactory = Preconditions.CheckNotNull(clientCredentialsFactory, nameof(clientCredentialsFactory));

this.connectionSettings = new AmqpConnectionSettings
{
ContainerId = "DeviceGateway_" + Guid.NewGuid().ToString("N"),
HostName = transportSettings.HostName,
// 'IdleTimeOut' on connection settings will be used to close connection if server hasn't
// 'IdleTimeOut' on connection settings will be used to close connection if server hasn't
// received any packet for 'IdleTimeout'
// Open frame send to client will have the IdleTimeout set and the client will do heart beat
// every 'IdleTimeout * 7 / 8'
Expand All @@ -63,7 +70,7 @@ public async Task StartAsync()
{
Events.Starting();

var amqpWebSocketListener = new AmqpWebSocketListener();
var amqpWebSocketListener = new AmqpWebSocketListener(this.authenticator, this.clientCredentialsFactory);
// This transport settings object sets up a listener for TLS over TCP and a listener for WebSockets.
TransportListener[] listeners = { this.transportSettings.Settings.CreateListener(), amqpWebSocketListener };

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,48 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
using Microsoft.Extensions.Logging;
using Microsoft.Azure.Amqp.Transport;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;

class AmqpWebSocketListener : TransportListener, IWebSocketListener
{
public string SubProtocol => Constants.WebSocketSubProtocol;

public AmqpWebSocketListener()
readonly IAuthenticator authenticator;
readonly IClientCredentialsFactory clientCredentialsFactory;
public AmqpWebSocketListener(IAuthenticator authenticator,
IClientCredentialsFactory clientCredentialsFactory)
: base(Constants.WebSocketListenerName)
{
this.authenticator = Preconditions.CheckNotNull(authenticator, nameof(authenticator));
this.clientCredentialsFactory = Preconditions.CheckNotNull(clientCredentialsFactory, nameof(clientCredentialsFactory));
}

public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId)
public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId, X509Certificate2 clientCert, IList<X509Certificate2> clientCertChain)
{
try
{
var taskCompletion = new TaskCompletionSource<bool>();

string localEndpointValue = localEndPoint.Expect(() => new ArgumentNullException(nameof(localEndPoint))).ToString();
var transport = new ServerWebSocketTransport(webSocket, localEndpointValue, remoteEndPoint.ToString(), correlationId);
ServerWebSocketTransport transport;
if ((clientCert != null) && (clientCertChain != null))
{
transport = new ServerWebSocketTransport(webSocket,
localEndpointValue,
remoteEndPoint.ToString(),
correlationId,
clientCert,
clientCertChain,
this.authenticator,
this.clientCredentialsFactory);
}
else
{
transport = new ServerWebSocketTransport(webSocket,
localEndpointValue,
remoteEndPoint.ToString(),
correlationId);
}

transport.Open();

var args = new TransportAsyncCallbackArgs { Transport = transport, CompletedSynchronously = false };
Expand All @@ -52,8 +76,8 @@ public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPo
}
}

public Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId, X509Certificate2 clientCert, IList<X509Certificate2> clientCertChain)
=> this.ProcessWebSocketRequestAsync(webSocket, localEndPoint, remoteEndPoint, correlationId);
public async Task ProcessWebSocketRequestAsync(WebSocket webSocket, Option<EndPoint> localEndPoint, EndPoint remoteEndPoint, string correlationId)
=> await this.ProcessWebSocketRequestAsync(webSocket, localEndPoint, remoteEndPoint, correlationId, null, null);

protected override void OnListen()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
using System;
using Microsoft.Azure.Devices.Common.Exceptions;

public class EdgeHubAmqpException : Exception
public class EdgeAmqpException : Exception
{
public EdgeHubAmqpException(string message, ErrorCode errorCode)
public EdgeAmqpException(string message, ErrorCode errorCode)
: this(message, errorCode, null)
{ }

public EdgeHubAmqpException(string message, ErrorCode errorCode, Exception innerException)
public EdgeAmqpException(string message, ErrorCode errorCode, Exception innerException)
: base(message, innerException)
{
this.ErrorCode = errorCode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Extensions.Logging;

public class EdgeHubSaslPlainAuthenticator : ISaslPlainAuthenticator
public class EdgeSaslPlainAuthenticator : ISaslPlainAuthenticator
{
readonly IAuthenticator authenticator;
readonly IClientCredentialsFactory clientCredentialsFactory;
readonly string iotHubHostName;

public EdgeHubSaslPlainAuthenticator(IAuthenticator authenticator, IClientCredentialsFactory clientCredentialsFactory, string iotHubHostName)
public EdgeSaslPlainAuthenticator(IAuthenticator authenticator, IClientCredentialsFactory clientCredentialsFactory, string iotHubHostName)
{
this.clientCredentialsFactory = Preconditions.CheckNotNull(clientCredentialsFactory, nameof(clientCredentialsFactory));
this.authenticator = Preconditions.CheckNotNull(authenticator, nameof(authenticator));
Expand Down Expand Up @@ -63,7 +63,7 @@ public async Task<IPrincipal> AuthenticateAsync(string identity, string password

static class Events
{
static readonly ILogger Log = Logger.Factory.CreateLogger<EdgeHubSaslPlainAuthenticator>();
static readonly ILogger Log = Logger.Factory.CreateLogger<EdgeSaslPlainAuthenticator>();
const int IdStart = AmqpEventIds.SaslPlainAuthenticator;

enum EventIds
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using System.Collections.Generic;
using System.Linq;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Azure.Amqp.Transport;
using Microsoft.Azure.Amqp.X509;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;

public class EdgeTlsTransport : TlsTransport
{
readonly IClientCredentialsFactory clientCredentialsProvider;
readonly IAuthenticator authenticator;
private IList<X509Certificate2> remoteCertificateChain;

public EdgeTlsTransport(
TransportBase innerTransport,
TlsTransportSettings tlsSettings,
IAuthenticator authenticator,
IClientCredentialsFactory clientCredentialsProvider)
: base(innerTransport, tlsSettings)
{
this.clientCredentialsProvider = Preconditions.CheckNotNull(clientCredentialsProvider, nameof(clientCredentialsProvider));
this.authenticator = Preconditions.CheckNotNull(authenticator, nameof(authenticator));
this.remoteCertificateChain = null;
}

protected override X509Principal CreateX509Principal(X509Certificate2 certificate)
{
var principal = new EdgeX509Principal(new X509CertificateIdentity(certificate, true),
this.remoteCertificateChain,
this.authenticator,
this.clientCredentialsProvider);
// release chain elements from here since principal has this
this.remoteCertificateChain = null;
return principal;
}

protected override bool ValidateRemoteCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
{
// copy of the chain elements since they are destroyed after this method completes
this.remoteCertificateChain = chain == null ? new List<X509Certificate2>() :
chain.ChainElements.Cast<X509ChainElement>().Select(element => element.Certificate).ToList();
return base.ValidateRemoteCertificate(sender, certificate, chain, sslPolicyErrors);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using Microsoft.Azure.Amqp.Transport;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;

public class EdgeTlsTransportListener : TlsTransportListener
{
readonly IClientCredentialsFactory clientCredentialsProvider;
readonly IAuthenticator authenticator;

public EdgeTlsTransportListener(
TlsTransportSettings transportSettings,
IAuthenticator authenticator,
IClientCredentialsFactory clientCredentialsProvider)
: base(transportSettings)
{
this.clientCredentialsProvider = Preconditions.CheckNotNull(clientCredentialsProvider, nameof(clientCredentialsProvider));
this.authenticator = Preconditions.CheckNotNull(authenticator, nameof(authenticator));
}

protected override TlsTransport OnCreateTransport(TransportBase innerTransport, TlsTransportSettings tlsTransportSettings) =>
new EdgeTlsTransport(innerTransport, tlsTransportSettings, this.authenticator, this.clientCredentialsProvider);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.Azure.Devices.Edge.Hub.Amqp
{
using System;
using Microsoft.Azure.Amqp.Transport;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;

public class EdgeTlsTransportSettings : TlsTransportSettings
{
readonly IClientCredentialsFactory clientCredentialsProvider;
readonly IAuthenticator authenticator;

public EdgeTlsTransportSettings(
TransportSettings innerSettings,
bool isInitiator,
IAuthenticator authenticator,
IClientCredentialsFactory clientCredentialsProvider)
: base(innerSettings, isInitiator)
{
this.clientCredentialsProvider = Preconditions.CheckNotNull(clientCredentialsProvider, nameof(clientCredentialsProvider));
this.authenticator = Preconditions.CheckNotNull(authenticator, nameof(authenticator));
}

public override TransportListener CreateListener()
{
if (this.Certificate == null)
{
throw new InvalidOperationException("Server certificate must be set");
}

return new EdgeTlsTransportListener(this, this.authenticator, this.clientCredentialsProvider);
}
}
}
Loading

0 comments on commit 875776c

Please sign in to comment.