From 93be5343561362c6244e5e42e09b413baecd53c3 Mon Sep 17 00:00:00 2001 From: Varun Puranik Date: Sat, 1 Dec 2018 22:26:02 -0800 Subject: [PATCH] EdgeHub: Allow multiplexing client connections over AMQP (#587) * Add AMQP Downstream Multiplexing support * Amqp Mux changes * Fix link handlers * Cleanup * Get product code to build * Cleanup * Fix tests * Fix tests * Format and cleanup * Fix merge * fix inheritance * Update edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/CbsNode.cs Co-Authored-By: varunpuranik * Update edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/CbsNode.cs Co-Authored-By: varunpuranik * Remove commented members * Add C2D subscription if not module identity * Fix tests --- .../AmqpAuthentication.cs | 22 ---- .../AmqpRuntimeProvider.cs | 2 +- .../CbsNode.cs | 21 +-- ...nHandler.cs => ClientConnectionHandler.cs} | 124 ++++++------------ .../ClientConnectionsHandler.cs | 23 ++++ .../EdgeHubSaslPlainAuthenticator.cs | 2 +- .../IAmqpAuthenticator.cs | 1 + .../ICbsNode.cs | 5 +- .../IClientConnectionsHandler.cs | 11 ++ .../IConnectionHandler.cs | 2 - .../SaslPrincipal.cs | 22 +++- .../Templates.cs | 11 -- .../linkhandlers/DeviceBoundLinkHandler.cs | 14 +- .../linkhandlers/EventsLinkHandler.cs | 31 +++-- .../linkhandlers/IReceivingLinkHandler.cs | 3 +- .../linkhandlers/LinkHandler.cs | 73 ++++++----- .../linkhandlers/LinkHandlerProvider.cs | 112 ++++++++++------ .../MethodReceivingLinkHandler.cs | 10 +- .../linkhandlers/MethodSendingLinkHandler.cs | 10 +- .../linkhandlers/ModuleMessageLinkHandler.cs | 10 +- .../linkhandlers/ReceivingLinkHandler.cs | 22 +++- .../linkhandlers/SendingLinkHandler.cs | 18 ++- .../linkhandlers/TwinReceivingLinkHandler.cs | 9 +- .../linkhandlers/TwinSendingLinkHandler.cs | 10 +- .../DeviceScopeTokenAuthenticator.cs | 2 +- .../modules/AmqpModule.cs | 3 +- .../CbsNodeTest.cs | 97 +------------- .../ConnectionHandlerTest.cs | 103 +++------------ .../DeviceBoundLinkHandlerTest.cs | 28 ++-- .../EdgeHubSaslPlainAuthenticatorTest.cs | 12 +- .../EventsLinkHandlerTest.cs | 46 ++++--- .../LinkHandlerProviderTest.cs | 99 ++++++++++---- .../ReceivingLinkHandlerTest.cs | 26 +++- .../SaslPrincipalTest.cs | 7 +- .../SendingLinkHandlerTest.cs | 65 ++++++--- .../TwinReceivingLinkHandlerTest.cs | 31 +++-- 36 files changed, 562 insertions(+), 525 deletions(-) delete mode 100644 edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpAuthentication.cs rename edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/{ConnectionHandler.cs => ClientConnectionHandler.cs} (71%) create mode 100644 edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionsHandler.cs create mode 100644 edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IClientConnectionsHandler.cs diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpAuthentication.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpAuthentication.cs deleted file mode 100644 index bb251e43257..00000000000 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpAuthentication.cs +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -namespace Microsoft.Azure.Devices.Edge.Hub.Amqp -{ - using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; - using Microsoft.Azure.Devices.Edge.Util; - - public class AmqpAuthentication - { - public static AmqpAuthentication Unauthenticated = new AmqpAuthentication(false, Option.None()); - - public AmqpAuthentication(bool isAuthenticated, Option clientCredentials) - { - this.IsAuthenticated = isAuthenticated; - this.ClientCredentials = clientCredentials; - } - - public bool IsAuthenticated { get; } - - public Option ClientCredentials { get; } - } -} diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpRuntimeProvider.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpRuntimeProvider.cs index e4a3296d3fa..5361cf65c04 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpRuntimeProvider.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/AmqpRuntimeProvider.cs @@ -84,7 +84,7 @@ void OnConnectionOpening(object sender, OpenEventArgs e) amqpConnection.Extensions.Add(cbsNode); } - IConnectionHandler connectionHandler = new ConnectionHandler(new EdgeAmqpConnection(amqpConnection), this.connectionProvider); + IClientConnectionsHandler connectionHandler = new ClientConnectionsHandler(this.connectionProvider); amqpConnection.Extensions.Add(connectionHandler); } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/CbsNode.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/CbsNode.cs index 969ac93a576..a1893654db4 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/CbsNode.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/CbsNode.cs @@ -21,7 +21,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp /// This class is used to get tokens from the Client on the CBS link. It generates /// an identity from the received token and authenticates it. /// - class CbsNode : ICbsNode + class CbsNode : ICbsNode, IAmqpAuthenticator { static readonly List ResourceTemplates = new List { @@ -71,22 +71,6 @@ public void RegisterLink(IAmqpLink link) Events.LinkRegistered(link); } - // TODO: Temporary implementation - just get the first credentials and return it. - public async Task GetAmqpAuthentication() - { - if (!this.clientCredentialsMap.Any()) - { - throw new InvalidOperationException("No valid credentials found"); - } - - KeyValuePair creds = this.clientCredentialsMap.First(); - if (!creds.Value.IsAuthenticated) - { - creds.Value.IsAuthenticated = await this.authenticator.AuthenticateAsync(creds.Value.ClientCredentials); - } - return new AmqpAuthentication(creds.Value.IsAuthenticated, Option.Some(creds.Value.ClientCredentials)); - } - public async Task AuthenticateAsync(string id) { try @@ -119,7 +103,7 @@ public async Task AuthenticateAsync(string id) Events.ErrorAuthenticatingIdentity(id, e); return false; } - } + } async void OnMessageReceived(AmqpMessage message) { @@ -175,7 +159,6 @@ async Task HandleTokenUpdate(AmqpMessage message) { credentialsInfo.ClientCredentials = clientCredentials; } - if (credentialsInfo.IsAuthenticated) { await this.credentialsCache.Add(clientCredentials); diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ConnectionHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionHandler.cs similarity index 71% rename from edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ConnectionHandler.cs rename to edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionHandler.cs index 1f9297f0353..850ff4e5271 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ConnectionHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionHandler.cs @@ -3,6 +3,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp { using System; using System.Collections.Generic; + using System.Linq; using System.Threading.Tasks; using System.Web; using Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers; @@ -18,91 +19,42 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp /// It maintains the IIdentity and the IDeviceListener for the connection, and provides it to the link handlers. /// It also maintains a registry of the links open on that connection, and makes sure duplicate/invalid links are not opened. /// - class ConnectionHandler : IConnectionHandler + class ClientConnectionHandler : IConnectionHandler { readonly IDictionary registry = new Dictionary(); - bool isInitialized; - IDeviceListener deviceListener; - AmqpAuthentication amqpAuthentication; + readonly IIdentity identity; readonly AsyncLock initializationLock = new AsyncLock(); readonly AsyncLock registryUpdateLock = new AsyncLock(); - readonly IAmqpConnection connection; readonly IConnectionProvider connectionProvider; + Option deviceListener = Option.None(); - public ConnectionHandler(IAmqpConnection connection, IConnectionProvider connectionProvider) + public ClientConnectionHandler(IIdentity identity, IConnectionProvider connectionProvider) { - this.connection = Preconditions.CheckNotNull(connection, nameof(connection)); + this.identity = Preconditions.CheckNotNull(identity, nameof(identity)); this.connectionProvider = Preconditions.CheckNotNull(connectionProvider, nameof(connectionProvider)); } - public async Task GetDeviceListener() + public Task GetDeviceListener() { - await this.EnsureInitialized(); - return this.deviceListener; - } - - public async Task GetAmqpAuthentication() - { - await this.EnsureInitialized(); - return this.amqpAuthentication; - } - - async Task EnsureInitialized() - { - if (!this.isInitialized) - { - using (await this.initializationLock.LockAsync()) - { - if (!this.isInitialized) + return this.deviceListener.Map(d => Task.FromResult(d)) + .GetOrElse( + async () => { - AmqpAuthentication amqpAuth; - // Check if Principal is SaslPrincipal - if (this.connection.Principal is SaslPrincipal saslPrincipal) - { - amqpAuth = saslPrincipal.AmqpAuthentication; - } - else + using (await this.initializationLock.LockAsync()) { - // Else the connection uses CBS authentication. Get AmqpAuthentication from the CbsNode - var cbsNode = this.connection.FindExtension(); - if (cbsNode == null) - { - throw new InvalidOperationException("CbsNode is null"); - } - - amqpAuth = await cbsNode.GetAmqpAuthentication(); + return await this.deviceListener.Map(d => Task.FromResult(d)) + .GetOrElse( + async () => + { + IDeviceListener dl = await this.connectionProvider.GetDeviceListenerAsync(this.identity); + var deviceProxy = new DeviceProxy(this, this.identity); + dl.BindDeviceProxy(deviceProxy); + this.deviceListener = Option.Some(dl); + return dl; + }); } - - if (!amqpAuth.IsAuthenticated) - { - throw new InvalidOperationException("Connection not authenticated"); - } - - IClientCredentials clientCredentials = amqpAuth.ClientCredentials.Expect(() => new InvalidOperationException("Authenticated connection should have a valid identity")); - this.deviceListener = await this.connectionProvider.GetDeviceListenerAsync(clientCredentials.Identity); - var deviceProxy = new DeviceProxy(this, clientCredentials.Identity); - this.deviceListener.BindDeviceProxy(deviceProxy); - this.amqpAuthentication = amqpAuth; - this.isInitialized = true; - Events.InitializedConnectionHandler(clientCredentials.Identity); - } - } - } - } - - async Task> GetUpdatedAuthenticatedIdentity() - { - var cbsNode = this.connection.FindExtension(); - if (cbsNode != null) - { - AmqpAuthentication updatedAmqpAuthentication = await cbsNode.GetAmqpAuthentication(); - if (updatedAmqpAuthentication.IsAuthenticated) - { - return updatedAmqpAuthentication.ClientCredentials; - } - } - return Option.None(); + }); } public async Task RegisterLinkHandler(ILinkHandler linkHandler) @@ -170,23 +122,29 @@ public async Task RemoveLinkHandler(ILinkHandler linkHandler) } } + Task CloseAllLinks() + { + IList links = this.registry.Values.ToList(); + IEnumerable closeTasks = links.Select(l => l.CloseAsync(Constants.DefaultTimeout)); + return Task.WhenAll(closeTasks); + } + async Task CloseConnection() { using (await this.initializationLock.LockAsync()) { - this.isInitialized = false; - await (this.deviceListener?.CloseAsync() ?? Task.CompletedTask); + await this.deviceListener.ForEachAsync(d => d.CloseAsync()); } } public class DeviceProxy : IDeviceProxy { - readonly ConnectionHandler connectionHandler; + readonly ClientConnectionHandler clientConnectionHandler; readonly AtomicBoolean isActive = new AtomicBoolean(true); - public DeviceProxy(ConnectionHandler connectionHandler, IIdentity identity) + public DeviceProxy(ClientConnectionHandler clientConnectionHandler, IIdentity identity) { - this.connectionHandler = connectionHandler; + this.clientConnectionHandler = clientConnectionHandler; this.Identity = identity; } @@ -195,14 +153,14 @@ public Task CloseAsync(Exception ex) if (this.isActive.GetAndSet(false)) { Events.ClosingProxy(this.Identity, ex); - return this.connectionHandler.connection.Close(); + return this.clientConnectionHandler.CloseAllLinks(); } return Task.CompletedTask; } public Task SendC2DMessageAsync(IMessage message) { - if (!this.connectionHandler.registry.TryGetValue(LinkType.C2D, out ILinkHandler linkHandler)) + if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.C2D, out ILinkHandler linkHandler)) { Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "C2D message"); return Task.CompletedTask; @@ -216,7 +174,7 @@ public Task SendC2DMessageAsync(IMessage message) public Task SendMessageAsync(IMessage message, string input) { - if (!this.connectionHandler.registry.TryGetValue(LinkType.ModuleMessages, out ILinkHandler linkHandler)) + if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.ModuleMessages, out ILinkHandler linkHandler)) { Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "message"); return Task.CompletedTask; @@ -228,7 +186,7 @@ public Task SendMessageAsync(IMessage message, string input) public async Task InvokeMethodAsync(DirectMethodRequest request) { - if (!this.connectionHandler.registry.TryGetValue(LinkType.MethodSending, out ILinkHandler linkHandler)) + if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.MethodSending, out ILinkHandler linkHandler)) { Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "method request"); return default(DirectMethodResponse); @@ -251,7 +209,7 @@ public async Task InvokeMethodAsync(DirectMethodRequest re public Task OnDesiredPropertyUpdates(IMessage desiredProperties) { - if (!this.connectionHandler.registry.TryGetValue(LinkType.TwinSending, out ILinkHandler linkHandler)) + if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.TwinSending, out ILinkHandler linkHandler)) { Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "desired properties update"); return Task.CompletedTask; @@ -263,7 +221,7 @@ public Task OnDesiredPropertyUpdates(IMessage desiredProperties) public Task SendTwinUpdate(IMessage twin) { - if (!this.connectionHandler.registry.TryGetValue(LinkType.TwinSending, out ILinkHandler linkHandler)) + if (!this.clientConnectionHandler.registry.TryGetValue(LinkType.TwinSending, out ILinkHandler linkHandler)) { Events.LinkNotFound(LinkType.ModuleMessages, this.Identity, "twin update"); return Task.CompletedTask; @@ -283,12 +241,12 @@ public void SetInactive() this.isActive.Set(false); } - public Task> GetUpdatedIdentity() => this.connectionHandler.GetUpdatedAuthenticatedIdentity(); + public Task> GetUpdatedIdentity() => throw new NotImplementedException(); } static class Events { - static readonly ILogger Log = Logger.Factory.CreateLogger(); + static readonly ILogger Log = Logger.Factory.CreateLogger(); const int IdStart = AmqpEventIds.ConnectionHandler; enum EventIds diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionsHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionsHandler.cs new file mode 100644 index 00000000000..3d9613f9e44 --- /dev/null +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ClientConnectionsHandler.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.Azure.Devices.Edge.Hub.Amqp +{ + using System.Collections.Concurrent; + using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; + using Microsoft.Azure.Devices.Edge.Util; + + class ClientConnectionsHandler : IClientConnectionsHandler + { + readonly ConcurrentDictionary connectionHandlers = new ConcurrentDictionary(); + readonly IConnectionProvider connectionProvider; + + public ClientConnectionsHandler(IConnectionProvider connectionProvider) + { + this.connectionProvider = Preconditions.CheckNotNull(connectionProvider, nameof(connectionProvider)); + } + + public IConnectionHandler GetConnectionHandler(IIdentity identity) => + this.connectionHandlers.GetOrAdd(identity.Id, i => new ClientConnectionHandler(identity, this.connectionProvider)); + } +} diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/EdgeHubSaslPlainAuthenticator.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/EdgeHubSaslPlainAuthenticator.cs index 404c43fd5c2..cdf217efeea 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/EdgeHubSaslPlainAuthenticator.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/EdgeHubSaslPlainAuthenticator.cs @@ -52,7 +52,7 @@ public async Task AuthenticateAsync(string identity, string password throw new EdgeHubConnectionException("Authentication failed."); } - return new SaslPrincipal(new AmqpAuthentication(true, Option.Some(deviceIdentity))); + return new SaslPrincipal(true, deviceIdentity); } catch (Exception ex) when (!ex.IsFatal()) { diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IAmqpAuthenticator.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IAmqpAuthenticator.cs index f6d9a53cc46..9eb7e2d8f7d 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IAmqpAuthenticator.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IAmqpAuthenticator.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Hub.Amqp { using System.Threading.Tasks; diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ICbsNode.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ICbsNode.cs index c42d59e0c4d..a7bf42e0591 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ICbsNode.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/ICbsNode.cs @@ -2,12 +2,9 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp { using System; - using System.Threading.Tasks; - public interface ICbsNode : IAmqpAuthenticator, IDisposable + public interface ICbsNode : IDisposable { void RegisterLink(IAmqpLink link); - - Task GetAmqpAuthentication(); } } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IClientConnectionsHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IClientConnectionsHandler.cs new file mode 100644 index 00000000000..2ea7c9f0923 --- /dev/null +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IClientConnectionsHandler.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.Azure.Devices.Edge.Hub.Amqp +{ + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; + + public interface IClientConnectionsHandler + { + IConnectionHandler GetConnectionHandler(IIdentity identity); + } +} diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IConnectionHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IConnectionHandler.cs index 089d6efd34e..507cdfb891a 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IConnectionHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/IConnectionHandler.cs @@ -10,8 +10,6 @@ public interface IConnectionHandler { Task GetDeviceListener(); - Task GetAmqpAuthentication(); - Task RegisterLinkHandler(ILinkHandler linkHandler); Task RemoveLinkHandler(ILinkHandler linkHandler); diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/SaslPrincipal.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/SaslPrincipal.cs index 1413c7a35be..5e39970fdbc 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/SaslPrincipal.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/SaslPrincipal.cs @@ -4,20 +4,30 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp { using System; using System.Security.Principal; + using System.Threading.Tasks; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Util; + using IIdentity = System.Security.Principal.IIdentity; - class SaslPrincipal : IPrincipal + class SaslPrincipal : IPrincipal, IAmqpAuthenticator { - public SaslPrincipal(AmqpAuthentication amqpAuthentication) + readonly IClientCredentials clientCredentials; + readonly bool isAuthenticated; + + public SaslPrincipal(bool isAuthenticated, IClientCredentials clientCredentials) { - this.AmqpAuthentication = Preconditions.CheckNotNull(amqpAuthentication, nameof(amqpAuthentication)); - this.Identity = new GenericIdentity(amqpAuthentication.ClientCredentials.Map(i => i.Identity.Id).GetOrElse(string.Empty)); + this.isAuthenticated = isAuthenticated; + this.clientCredentials = Preconditions.CheckNotNull(clientCredentials, nameof(clientCredentials)); + this.Identity = new GenericIdentity(this.clientCredentials.Identity.Id); } - public AmqpAuthentication AmqpAuthentication { get; } - public IIdentity Identity { get; } public bool IsInRole(string role) => throw new NotImplementedException(); + + public Task AuthenticateAsync(string id) => + Task.FromResult( + this.isAuthenticated && + this.clientCredentials.Identity.Id.Equals(id)); } } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/Templates.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/Templates.cs index 9364c7db16c..5cca44c5df4 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/Templates.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/Templates.cs @@ -6,17 +6,12 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp public static class Templates { - public const string IoTHubAliasRootPrefix = "/$iothub"; - public const string DevicePathPrefix = "/devices/"; public const string ModulesPathPrefix = "/modules/"; public const string DeviceIdTemplateParameterName = "deviceid"; public const string ModuleIdTemplateParameterName = "moduleid"; - const string TelemetryEventHubReceiveRedirectPrefix = "/messages/events"; - const string OperationMonitoringEventHubReceiveRedirectPrefix = "/messages/operationsMonitoringEvents"; - public const string DeviceTelemetryStreamUriFormat = "/devices/{0}/messages/events"; public const string ModuleTelemetryStreamUriFormat = "/devices/{0}/modules/{1}/messages/events"; @@ -30,11 +25,6 @@ public static class Templates public static readonly UriPathTemplate ModuleEventsTemplate = new UriPathTemplate(ModuleTelemetryStreamUriFormat.FormatInvariant("{" + DeviceIdTemplateParameterName + "}", "{" + ModuleIdTemplateParameterName + "}")); public static readonly UriPathTemplate DeviceFromDeviceBoundTemplate = new UriPathTemplate(DeviceC2DStreamUriFormat.FormatInvariant("{" + DeviceIdTemplateParameterName + "}")); public static readonly UriPathTemplate ModuleFromDeviceBoundTemplate = new UriPathTemplate(ModuleC2DStreamUriFormat.FormatInvariant("{" + DeviceIdTemplateParameterName + "}", "{" + ModuleIdTemplateParameterName + "}")); - public static readonly UriPathTemplate ServiceToDeviceBoundTemplate = new UriPathTemplate("/messages/deviceBound"); - public static readonly UriPathTemplate FeedbackTemplate = new UriPathTemplate("/messages/serviceBound/feedback"); - public static readonly UriPathTemplate FileNotificationTemplate = new UriPathTemplate("/messages/serviceBound/filenotifications"); - public static readonly UriPathTemplate EventHubReceiveRedirectTemplate = new UriPathTemplate(TelemetryEventHubReceiveRedirectPrefix + "/*"); - public static readonly UriPathTemplate OperationMonitoringEventHubReceiveRedirectTemplate = new UriPathTemplate(OperationMonitoringEventHubReceiveRedirectPrefix + "/*"); public static class Twin { @@ -46,7 +36,6 @@ public static class Twin public static readonly UriPathTemplate ModuleDeviceBoundMethodCallTemplate = new UriPathTemplate(ModuleDeviceBoundMethodCallUriFormat.FormatInvariant("{" + DeviceIdTemplateParameterName + "}", "{" + ModuleIdTemplateParameterName + "}")); public static readonly UriPathTemplate TwinStreamTemplate = new UriPathTemplate(DeviceTwinMessageStreamUriFormat.FormatInvariant("{" + DeviceIdTemplateParameterName + "}")); public static readonly UriPathTemplate ModuleTwinStreamTemplate = new UriPathTemplate(ModuleTwinMessageStreamUriFormat.FormatInvariant("{" + DeviceIdTemplateParameterName + "}", "{" + ModuleIdTemplateParameterName + "}")); - public static readonly UriPathTemplate RootTwinStreamTemplate = new UriPathTemplate(IoTHubAliasRootPrefix + DeviceTwinMessageStreamUriFormat.FormatInvariant("{" + DeviceIdTemplateParameterName + "}")); } } } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/DeviceBoundLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/DeviceBoundLinkHandler.cs index d4266fc1196..7ab5fdb65a2 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/DeviceBoundLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/DeviceBoundLinkHandler.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers { using System; @@ -7,15 +8,21 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using Microsoft.Azure.Amqp; using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; /// /// Address matches the template "/devices/{0}/messages/deviceBound" /// public class DeviceBoundLinkHandler : SendingLinkHandler { - public DeviceBoundLinkHandler(ISendingAmqpLink link, Uri requestUri, IDictionary boundVariables, + public DeviceBoundLinkHandler( + IIdentity identity, + ISendingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { } @@ -28,8 +35,7 @@ protected override async Task OnOpenAsync(TimeSpan timeout) // TODO: Check if we need to worry about credit available on the link await base.OnOpenAsync(timeout); - // TODO: Temporary fix since SDK subscribes to C2D messages for modules. - if (string.IsNullOrWhiteSpace(this.ModuleId)) + if (!(this.Identity is IModuleIdentity)) { await this.DeviceListener.AddSubscription(DeviceSubscription.C2D); } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/EventsLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/EventsLinkHandler.cs index aa4a836b705..780d5be9606 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/EventsLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/EventsLinkHandler.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers { using System; @@ -8,6 +9,8 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using Microsoft.Azure.Amqp; using Microsoft.Azure.Amqp.Framing; using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Device; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Extensions.Logging; @@ -19,9 +22,14 @@ class EventsLinkHandler : ReceivingLinkHandler { static readonly long MaxBatchedMessageSize = 600 * 1024; - public EventsLinkHandler(IReceivingAmqpLink link, Uri requestUri, IDictionary boundVariables, + public EventsLinkHandler( + IIdentity identity, + IReceivingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { } @@ -63,14 +71,15 @@ protected override async Task OnMessageReceived(AmqpMessage amqpMessage) void AddMessageSystemProperties(IMessage message) { - if (!string.IsNullOrWhiteSpace(this.DeviceId)) + if (this.Identity is IDeviceIdentity deviceIdentity) { - message.SystemProperties[SystemProperties.ConnectionDeviceId] = this.DeviceId; + message.SystemProperties[SystemProperties.ConnectionDeviceId] = deviceIdentity.DeviceId; } - if (!string.IsNullOrWhiteSpace(this.ModuleId)) + if (this.Identity is IModuleIdentity moduleIdentity) { - message.SystemProperties[SystemProperties.ConnectionModuleId] = this.ModuleId; + message.SystemProperties[SystemProperties.ConnectionDeviceId] = moduleIdentity.DeviceId; + message.SystemProperties[SystemProperties.ConnectionModuleId] = moduleIdentity.ModuleId; } } @@ -83,10 +92,12 @@ internal static IList ExpandBatchedMessage(AmqpMessage message) foreach (Data data in message.DataBody) { var payload = (ArraySegment)data.Value; - AmqpMessage debatchedMessage = AmqpMessage.CreateAmqpStreamMessage(new BufferListStream(new List>() - { - payload - })); + AmqpMessage debatchedMessage = AmqpMessage.CreateAmqpStreamMessage( + new BufferListStream( + new List>() + { + payload + })); outputMessages.Add(debatchedMessage); } } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/IReceivingLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/IReceivingLinkHandler.cs index fe6d0f032c3..1b0d7716864 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/IReceivingLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/IReceivingLinkHandler.cs @@ -3,5 +3,6 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers { public interface IReceivingLinkHandler : ILinkHandler - { } + { + } } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/LinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/LinkHandler.cs index a531fc9ea4e..fb98120bf11 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/LinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/LinkHandler.cs @@ -15,42 +15,35 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers public abstract class LinkHandler : ILinkHandler { - IDeviceListener deviceListener; - - protected LinkHandler(IAmqpLink link, Uri requestUri, - IDictionary boundVariables, IMessageConverter messageConverter) + readonly IConnectionHandler connectionHandler; + + protected LinkHandler( + IIdentity identity, + IAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, + IMessageConverter messageConverter) { - // TODO: IoT Hub periodically validates that the authorization is still valid in this - // class using a timer (except when the concrete sub-class is CbsLinkHandler or EventHubReceiveRedirectLinkHandler. - // We need to evaluate whether it makes sense to do that in Edge Hub too. See the implementation in - // AmqpGatewayProtocolHead.LinkHandler.IotHubStatusTimerCallback in service code. - + this.Identity = Preconditions.CheckNotNull(identity, nameof(identity)); this.MessageConverter = Preconditions.CheckNotNull(messageConverter, nameof(messageConverter)); this.BoundVariables = Preconditions.CheckNotNull(boundVariables, nameof(boundVariables)); this.Link = Preconditions.CheckNotNull(link, nameof(link)); this.LinkUri = Preconditions.CheckNotNull(requestUri, nameof(requestUri)); this.Link.SafeAddClosed(this.OnLinkClosed); - this.ConnectionHandler = this.Link.Session.Connection.FindExtension(); - this.DeviceId = this.BoundVariables.ContainsKey(Templates.DeviceIdTemplateParameterName) ? this.BoundVariables[Templates.DeviceIdTemplateParameterName] : string.Empty; - this.ModuleId = this.BoundVariables.ContainsKey(Templates.ModuleIdTemplateParameterName) ? this.BoundVariables[Templates.ModuleIdTemplateParameterName] : string.Empty; + this.connectionHandler = Preconditions.CheckNotNull(connectionHandler, nameof(connectionHandler)); } - protected string DeviceId { get; } - - protected string ModuleId { get; } + protected IIdentity Identity { get; } - protected string ClientId => this.DeviceId + (!string.IsNullOrWhiteSpace(this.ModuleId) ? $"/{this.ModuleId}" : string.Empty); + protected string ClientId => this.Identity.Id; protected IMessageConverter MessageConverter { get; } - protected IDeviceListener DeviceListener => this.deviceListener; + protected IDeviceListener DeviceListener { get; private set; } protected IDictionary BoundVariables { get; } - protected IIdentity Identity => this.deviceListener?.Identity; - - protected IConnectionHandler ConnectionHandler { get; } - public IAmqpLink Link { get; } public Uri LinkUri { get; } @@ -61,28 +54,46 @@ protected LinkHandler(IAmqpLink link, Uri requestUri, public async Task OpenAsync(TimeSpan timeout) { - if (!this.Link.IsCbsLink()) + if (!await this.Authenticate()) { - if (!await this.Authenticate()) - { - throw new InvalidOperationException($"Unable to open {this.Type} link as connection is not authenticated"); - } - - this.deviceListener = await this.ConnectionHandler.GetDeviceListener(); + throw new InvalidOperationException($"Unable to open {this.Type} link as the connection could not be authenticated"); } + + this.DeviceListener = await this.connectionHandler.GetDeviceListener(); + await this.OnOpenAsync(timeout); - await this.ConnectionHandler.RegisterLinkHandler(this); + await this.connectionHandler.RegisterLinkHandler(this); Events.Opened(this); } protected abstract Task OnOpenAsync(TimeSpan timeout); - protected async Task Authenticate() => (await this.ConnectionHandler.GetAmqpAuthentication()).IsAuthenticated; + protected Task Authenticate() + { + IAmqpAuthenticator amqpAuth; + IAmqpConnection connection = this.Link.Session.Connection; + + // Check if Principal is IAmqpAuthenticator + if (connection.Principal is IAmqpAuthenticator connAuth) + { + amqpAuth = connAuth; + } + else if (connection.FindExtension() is IAmqpAuthenticator cbsAuth) + { + amqpAuth = cbsAuth; + } + else + { + throw new InvalidOperationException($"Unable to find authentication mechanism for AMQP connection for identity {this.Identity.Id}"); + } + + return amqpAuth.AuthenticateAsync(this.Identity.Id); + } protected virtual void OnLinkClosed(object sender, EventArgs args) { Events.Closed(this); - this.ConnectionHandler.RemoveLinkHandler(this); + this.connectionHandler.RemoveLinkHandler(this); } public async Task CloseAsync(TimeSpan timeout) diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/LinkHandlerProvider.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/LinkHandlerProvider.cs index 3e28e437791..22b4fcc16a1 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/LinkHandlerProvider.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/LinkHandlerProvider.cs @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using System.Collections.Generic; using Microsoft.Azure.Amqp; using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Util; public class LinkHandlerProvider : ILinkHandlerProvider @@ -14,40 +15,47 @@ public class LinkHandlerProvider : ILinkHandlerProvider { { (Templates.CbsReceiveTemplate, true), LinkType.Cbs }, { (Templates.CbsReceiveTemplate, false), LinkType.Cbs }, - { ( Templates.DeviceEventsTemplate, true), LinkType.Events }, - { ( Templates.ModuleEventsTemplate, true), LinkType.Events }, - { ( Templates.ModuleEventsTemplate, false), LinkType.ModuleMessages }, - { ( Templates.DeviceFromDeviceBoundTemplate, false), LinkType.C2D }, - { ( Templates.ModuleFromDeviceBoundTemplate, false), LinkType.C2D }, - { ( Templates.Twin.DeviceBoundMethodCallTemplate, true), LinkType.MethodReceiving }, - { ( Templates.Twin.ModuleDeviceBoundMethodCallTemplate, true), LinkType.MethodReceiving }, - { ( Templates.Twin.DeviceBoundMethodCallTemplate, false), LinkType.MethodSending }, - { ( Templates.Twin.ModuleDeviceBoundMethodCallTemplate, false), LinkType.MethodSending }, - { ( Templates.Twin.TwinStreamTemplate, true), LinkType.TwinReceiving }, - { ( Templates.Twin.ModuleTwinStreamTemplate, true), LinkType.TwinReceiving }, - { ( Templates.Twin.TwinStreamTemplate, false), LinkType.TwinSending }, - { ( Templates.Twin.ModuleTwinStreamTemplate, false), LinkType.TwinSending }, + { (Templates.DeviceEventsTemplate, true), LinkType.Events }, + { (Templates.ModuleEventsTemplate, true), LinkType.Events }, + { (Templates.ModuleEventsTemplate, false), LinkType.ModuleMessages }, + { (Templates.DeviceFromDeviceBoundTemplate, false), LinkType.C2D }, + { (Templates.ModuleFromDeviceBoundTemplate, false), LinkType.C2D }, + { (Templates.Twin.DeviceBoundMethodCallTemplate, true), LinkType.MethodReceiving }, + { (Templates.Twin.ModuleDeviceBoundMethodCallTemplate, true), LinkType.MethodReceiving }, + { (Templates.Twin.DeviceBoundMethodCallTemplate, false), LinkType.MethodSending }, + { (Templates.Twin.ModuleDeviceBoundMethodCallTemplate, false), LinkType.MethodSending }, + { (Templates.Twin.TwinStreamTemplate, true), LinkType.TwinReceiving }, + { (Templates.Twin.ModuleTwinStreamTemplate, true), LinkType.TwinReceiving }, + { (Templates.Twin.TwinStreamTemplate, false), LinkType.TwinSending }, + { (Templates.Twin.ModuleTwinStreamTemplate, false), LinkType.TwinSending }, }; readonly IMessageConverter messageConverter; readonly IMessageConverter twinMessageConverter; readonly IMessageConverter methodMessageConverter; + readonly IIdentityProvider identityProvider; readonly IDictionary<(UriPathTemplate Template, bool IsReceiver), LinkType> templatesList; - public LinkHandlerProvider(IMessageConverter messageConverter, + public LinkHandlerProvider( + IMessageConverter messageConverter, IMessageConverter twinMessageConverter, - IMessageConverter methodMessageConverter) - : this(messageConverter, twinMessageConverter, methodMessageConverter, DefaultTemplatesList) - { } + IMessageConverter methodMessageConverter, + IIdentityProvider identityProvider) + : this(messageConverter, twinMessageConverter, methodMessageConverter, identityProvider, DefaultTemplatesList) + { + } - public LinkHandlerProvider(IMessageConverter messageConverter, + public LinkHandlerProvider( + IMessageConverter messageConverter, IMessageConverter twinMessageConverter, IMessageConverter methodMessageConverter, + IIdentityProvider identityProvider, IDictionary<(UriPathTemplate Template, bool IsReceiver), LinkType> templatesList) { this.messageConverter = Preconditions.CheckNotNull(messageConverter, nameof(messageConverter)); this.twinMessageConverter = Preconditions.CheckNotNull(twinMessageConverter, nameof(twinMessageConverter)); this.methodMessageConverter = Preconditions.CheckNotNull(methodMessageConverter, nameof(methodMessageConverter)); + this.identityProvider = Preconditions.CheckNotNull(identityProvider, nameof(identityProvider)); this.templatesList = Preconditions.CheckNotNull(templatesList, nameof(templatesList)); } @@ -63,35 +71,64 @@ public ILinkHandler Create(IAmqpLink link, Uri uri) internal ILinkHandler GetLinkHandler(LinkType linkType, IAmqpLink link, Uri uri, IDictionary boundVariables) { - switch (linkType) + if (linkType == LinkType.Cbs) + { + return CbsLinkHandler.Create(link, uri); + } + else { - case LinkType.Cbs: - return CbsLinkHandler.Create(link, uri); + IIdentity identity = this.GetIdentity(boundVariables); + IConnectionHandler connectionHandler = this.GetConnectionHandler(link, identity); + switch (linkType) + { + case LinkType.C2D: + return new DeviceBoundLinkHandler(identity, link as ISendingAmqpLink, uri, boundVariables, connectionHandler, this.messageConverter); - case LinkType.C2D: - return new DeviceBoundLinkHandler(link as ISendingAmqpLink, uri, boundVariables, this.messageConverter); + case LinkType.Events: + return new EventsLinkHandler(identity, link as IReceivingAmqpLink, uri, boundVariables, connectionHandler, this.messageConverter); - case LinkType.Events: - return new EventsLinkHandler(link as IReceivingAmqpLink, uri, boundVariables, this.messageConverter); + case LinkType.ModuleMessages: + return new ModuleMessageLinkHandler(identity, link as ISendingAmqpLink, uri, boundVariables, connectionHandler, this.messageConverter); - case LinkType.ModuleMessages: - return new ModuleMessageLinkHandler(link as ISendingAmqpLink, uri, boundVariables, this.messageConverter); + case LinkType.MethodSending: + return new MethodSendingLinkHandler(identity, link as ISendingAmqpLink, uri, boundVariables, connectionHandler, this.methodMessageConverter); - case LinkType.MethodSending: - return new MethodSendingLinkHandler(link as ISendingAmqpLink, uri, boundVariables, this.methodMessageConverter); + case LinkType.MethodReceiving: + return new MethodReceivingLinkHandler(identity, link as IReceivingAmqpLink, uri, boundVariables, connectionHandler, this.methodMessageConverter); - case LinkType.MethodReceiving: - return new MethodReceivingLinkHandler(link as IReceivingAmqpLink, uri, boundVariables, this.methodMessageConverter); + case LinkType.TwinReceiving: + return new TwinReceivingLinkHandler(identity, link as IReceivingAmqpLink, uri, boundVariables, connectionHandler, this.twinMessageConverter); - case LinkType.TwinReceiving: - return new TwinReceivingLinkHandler(link as IReceivingAmqpLink, uri, boundVariables, this.twinMessageConverter); + case LinkType.TwinSending: + return new TwinSendingLinkHandler(identity, link as ISendingAmqpLink, uri, boundVariables, connectionHandler, this.twinMessageConverter); - case LinkType.TwinSending: - return new TwinSendingLinkHandler(link as ISendingAmqpLink, uri, boundVariables, this.twinMessageConverter); + default: + throw new InvalidOperationException($"Invalid link type {linkType}"); + } + } + } - default: - throw new InvalidOperationException($"Invalid link type {linkType}"); + IConnectionHandler GetConnectionHandler(IAmqpLink link, IIdentity identity) + { + var amqpClientConnectionsHandler = link.Session.Connection.FindExtension(); + if (amqpClientConnectionsHandler == null) + { + throw new InvalidOperationException("Expected extension IAmqpClientConnectionsHandler not found on connection"); } + + return amqpClientConnectionsHandler.GetConnectionHandler(identity); + } + + IIdentity GetIdentity(IDictionary boundVariables) + { + if (!boundVariables.TryGetValue(Templates.DeviceIdTemplateParameterName, out string deviceId)) + { + throw new InvalidOperationException("Link should contain a device Id"); + } + + return boundVariables.TryGetValue(Templates.ModuleIdTemplateParameterName, out string moduleId) + ? this.identityProvider.Create(deviceId, moduleId) + : this.identityProvider.Create(deviceId); } internal (LinkType LinkType, IDictionary BoundVariables) GetLinkType(IAmqpLink link, Uri uri) @@ -103,6 +140,7 @@ internal ILinkHandler GetLinkHandler(LinkType linkType, IAmqpLink link, Uri uri, return (this.templatesList[key], boundVariables.ToDictionary()); } } + throw new InvalidOperationException($"Matching template not found for uri {uri}"); } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/MethodReceivingLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/MethodReceivingLinkHandler.cs index f52e67d32cf..17441ae7a31 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/MethodReceivingLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/MethodReceivingLinkHandler.cs @@ -7,6 +7,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using System.Threading.Tasks; using Microsoft.Azure.Amqp; using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; /// /// This class handles direct method responses from the client. @@ -15,9 +16,14 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers /// public class MethodReceivingLinkHandler : ReceivingLinkHandler { - public MethodReceivingLinkHandler(IReceivingAmqpLink link, Uri requestUri, IDictionary boundVariables, + public MethodReceivingLinkHandler( + IIdentity identity, + IReceivingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/MethodSendingLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/MethodSendingLinkHandler.cs index 7b1d497ac03..2f3532f2e7a 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/MethodSendingLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/MethodSendingLinkHandler.cs @@ -8,6 +8,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using Microsoft.Azure.Amqp; using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; /// /// This handles direct method requests to the client. @@ -16,9 +17,14 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers /// public class MethodSendingLinkHandler : SendingLinkHandler { - public MethodSendingLinkHandler(ISendingAmqpLink link, Uri requestUri, IDictionary boundVariables, + public MethodSendingLinkHandler( + IIdentity identity, + ISendingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/ModuleMessageLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/ModuleMessageLinkHandler.cs index b535b1ff624..99bd3c7d875 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/ModuleMessageLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/ModuleMessageLinkHandler.cs @@ -8,6 +8,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using Microsoft.Azure.Amqp; using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; /// /// This handler is used to send messages to modules @@ -15,9 +16,14 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers /// class ModuleMessageLinkHandler : SendingLinkHandler { - public ModuleMessageLinkHandler(ISendingAmqpLink link, Uri requestUri, IDictionary boundVariables, + public ModuleMessageLinkHandler( + IIdentity identity, + ISendingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/ReceivingLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/ReceivingLinkHandler.cs index 22116f0ed68..a74d342de91 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/ReceivingLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/ReceivingLinkHandler.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers { using System; @@ -8,6 +9,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using Microsoft.Azure.Amqp; using Microsoft.Azure.Amqp.Framing; using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Extensions.Logging; @@ -18,9 +20,14 @@ public abstract class ReceivingLinkHandler : LinkHandler, IReceivingLinkHandler { readonly ActionBlock sendMessageProcessor; - protected ReceivingLinkHandler(IReceivingAmqpLink link, Uri requestUri, IDictionary boundVariables, + protected ReceivingLinkHandler( + IIdentity identity, + IReceivingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { Preconditions.CheckArgument(link.IsReceiver, $"Link {requestUri} cannot receive"); this.ReceivingLink = link; @@ -39,27 +46,28 @@ protected override Task OnOpenAsync(TimeSpan timeout) // The receiver will only settle after sending the disposition to the sender and receiving a disposition indicating settlement of the delivery from the sender. this.ReceivingLink.Settings.RcvSettleMode = (byte)ReceiverSettleMode.Second; // SenderSettleMode.Unsettled (null as it is the default and to avoid bytes on the wire) - this.ReceivingLink.Settings.SndSettleMode = null; + this.ReceivingLink.Settings.SndSettleMode = null; break; case QualityOfService.AtLeastOnce: // The Receiver will spontaneously settle all incoming transfers. - this.ReceivingLink.Settings.RcvSettleMode = null;// Default ReceiverSettleMode.First; + this.ReceivingLink.Settings.RcvSettleMode = null; // Default ReceiverSettleMode.First; // The Sender will send all deliveries unsettled to the receiver. this.ReceivingLink.Settings.SndSettleMode = null; // Default SenderSettleMode.Unettled; break; case QualityOfService.AtMostOnce: // The Receiver will spontaneously settle all incoming transfers. - this.ReceivingLink.Settings.RcvSettleMode = null;// Default ReceiverSettleMode.First; + this.ReceivingLink.Settings.RcvSettleMode = null; // Default ReceiverSettleMode.First; // The Sender will send all deliveries unsettled to the receiver. this.ReceivingLink.Settings.SndSettleMode = (byte)SenderSettleMode.Settled; break; } this.ReceivingLink.RegisterMessageListener(m => this.sendMessageProcessor.Post(m)); - this.ReceivingLink.SafeAddClosed((s, e) => this.OnReceiveLinkClosed() - .ContinueWith(t => Events.ErrorClosingLink(t.Exception, this), TaskContinuationOptions.OnlyOnFaulted)); + this.ReceivingLink.SafeAddClosed( + (s, e) => this.OnReceiveLinkClosed() + .ContinueWith(t => Events.ErrorClosingLink(t.Exception, this), TaskContinuationOptions.OnlyOnFaulted)); return Task.CompletedTask; } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/SendingLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/SendingLinkHandler.cs index 8938b8ffccb..aebaa5b32f4 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/SendingLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/SendingLinkHandler.cs @@ -5,10 +5,10 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using System; using System.Collections.Generic; using System.Threading.Tasks; - using System.Threading.Tasks.Dataflow; using Microsoft.Azure.Amqp; using Microsoft.Azure.Amqp.Framing; using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Extensions.Logging; @@ -17,9 +17,14 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers /// public abstract class SendingLinkHandler : LinkHandler, ISendingLinkHandler { - protected SendingLinkHandler(ISendingAmqpLink link, Uri requestUri, - IDictionary boundVariables, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + protected SendingLinkHandler( + IIdentity identity, + ISendingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, + IMessageConverter messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { Preconditions.CheckArgument(!link.IsReceiver, $"Link {requestUri} cannot send"); this.SendingAmqpLink = link; @@ -56,6 +61,7 @@ protected override Task OnOpenAsync(TimeSpan timeout) this.SendingAmqpLink.Settings.SndSettleMode = (byte)SenderSettleMode.Settled; break; } + return Task.CompletedTask; } @@ -81,12 +87,14 @@ public Task SendMessage(IMessage message) { return this.SendingAmqpLink.SendMessageAsync(amqpMessage, deliveryTag, AmqpConstants.NullBinary, Amqp.Constants.DefaultTimeout); } + Events.MessageSent(this, message); } catch (Exception ex) { Events.ErrorProcessingMessage(ex, this); } + return Task.CompletedTask; } @@ -167,8 +175,10 @@ string GetMessageId() { messageId = string.Empty; } + return messageId; } + Log.LogDebug((int)EventIds.MessageSent, $"Sent message with id {GetMessageId()} to device {handler.ClientId}"); } } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/TwinReceivingLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/TwinReceivingLinkHandler.cs index 889f7e2cc66..9fdb504f05b 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/TwinReceivingLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/TwinReceivingLinkHandler.cs @@ -7,6 +7,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using System.Threading.Tasks; using Microsoft.Azure.Amqp; using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Extensions.Logging; @@ -22,11 +23,13 @@ public class TwinReceivingLinkHandler : ReceivingLinkHandler public const string TwinPut = "PUT"; public const string TwinDelete = "DELETE"; - public TwinReceivingLinkHandler(IReceivingAmqpLink link, + public TwinReceivingLinkHandler( + IIdentity identity, + IReceivingAmqpLink link, Uri requestUri, - IDictionary boundVariables, + IDictionary boundVariables, IConnectionHandler connectionHandler, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/TwinSendingLinkHandler.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/TwinSendingLinkHandler.cs index d5f00e8c258..f96fe5e5430 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/TwinSendingLinkHandler.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Amqp/linkhandlers/TwinSendingLinkHandler.cs @@ -8,6 +8,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers using Microsoft.Azure.Amqp; using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; /// /// This class handles sending twin messages to the client (Get twin responses and @@ -16,9 +17,14 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers /// public class TwinSendingLinkHandler : SendingLinkHandler { - public TwinSendingLinkHandler(ISendingAmqpLink link, Uri requestUri, IDictionary boundVariables, + public TwinSendingLinkHandler( + IIdentity identity, + ISendingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs index 06eff166409..7a773ea2a43 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs @@ -29,7 +29,7 @@ public DeviceScopeTokenAuthenticator( base(deviceScopeIdentitiesCache, underlyingAuthenticator, allowDeviceAuthForModule, syncServiceIdentityOnFailure) { this.iothubHostName = Preconditions.CheckNonWhiteSpace(iothubHostName, nameof(iothubHostName)); - this.edgeHubHostName = Preconditions.CheckNonWhiteSpace(edgeHubHostName, nameof(edgeHubHostName)); + this.edgeHubHostName = Preconditions.CheckNotNull(edgeHubHostName, nameof(edgeHubHostName)); } protected override bool AreInputCredentialsValid(ITokenCredentials credentials) => this.TryGetSharedAccessSignature(credentials.Token, credentials.Identity, out SharedAccessSignature _); diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/AmqpModule.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/AmqpModule.cs index a0d507f61c8..55027ee65a0 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/AmqpModule.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/AmqpModule.cs @@ -57,7 +57,8 @@ protected override void Load(ContainerBuilder builder) IMessageConverter messageConverter = new AmqpMessageConverter(); IMessageConverter twinMessageConverter = new AmqpTwinMessageConverter(); IMessageConverter directMethodMessageConverter = new AmqpDirectMethodMessageConverter(); - ILinkHandlerProvider linkHandlerProvider = new LinkHandlerProvider(messageConverter, twinMessageConverter, directMethodMessageConverter); + var identityProvider = c.Resolve(); + ILinkHandlerProvider linkHandlerProvider = new LinkHandlerProvider(messageConverter, twinMessageConverter, directMethodMessageConverter, identityProvider); return linkHandlerProvider; }) .As() diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/CbsNodeTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/CbsNodeTest.cs index 3b116e10045..2680e623f67 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/CbsNodeTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/CbsNodeTest.cs @@ -182,105 +182,12 @@ public async Task UpdateCbsTokenTest() // Act (AmqpResponseStatusCode statusCode, string description) = await cbsNode.UpdateCbsToken(validAmqpMessage); - - AmqpAuthentication amqpAuthentication = await cbsNode.GetAmqpAuthentication(); + bool isAuthenticated = await cbsNode.AuthenticateAsync(identity.Id); // Assert - Assert.Equal(true, amqpAuthentication.IsAuthenticated); - Assert.True(amqpAuthentication.ClientCredentials.HasValue); - Assert.Equal(identity, amqpAuthentication.ClientCredentials.OrDefault().Identity); + Assert.True(isAuthenticated); Assert.Equal(AmqpResponseStatusCode.OK, statusCode); Assert.Equal(AmqpResponseStatusCode.OK.ToString(), description); } - - [Fact] - public async Task HandleMultipleTokensTest() - { - // Arrange - string iotHubHostName = "edgehubtest1.azure-devices.net"; - - var amqpValue1 = new AmqpValue - { - Value = TokenHelper.CreateSasToken("edgehubtest1.azure-devices.net/devices/device1/modules/mod1") - }; - AmqpMessage validAmqpMessage1 = AmqpMessage.Create(amqpValue1); - validAmqpMessage1.ApplicationProperties.Map[CbsConstants.PutToken.Type] = "azure-devices.net:sastoken"; - validAmqpMessage1.ApplicationProperties.Map[CbsConstants.PutToken.Audience] = "iothub"; - validAmqpMessage1.ApplicationProperties.Map[CbsConstants.Operation] = CbsConstants.PutToken.OperationValue; - - var identity1 = Mock.Of(i => i.Id == "device1/mod1"); - var clientCredentials1 = Mock.Of(c => c.Identity == identity1); - var clientCredentialsFactory = new Mock(); - clientCredentialsFactory.Setup(i => i.GetWithSasToken("device1", "mod1", It.IsAny(), It.IsAny(), true)) - .Returns(clientCredentials1); - - var amqpValue2 = new AmqpValue - { - Value = TokenHelper.CreateSasToken("edgehubtest1.azure-devices.net/devices/device1/modules/mod2") - }; - AmqpMessage validAmqpMessage2 = AmqpMessage.Create(amqpValue2); - validAmqpMessage2.ApplicationProperties.Map[CbsConstants.PutToken.Type] = "azure-devices.net:sastoken"; - validAmqpMessage2.ApplicationProperties.Map[CbsConstants.PutToken.Audience] = "iothub"; - validAmqpMessage2.ApplicationProperties.Map[CbsConstants.Operation] = CbsConstants.PutToken.OperationValue; - - var identity2 = Mock.Of(i => i.Id == "device1/mod2"); - var clientCredentials2 = Mock.Of(c => c.Identity == identity2); - clientCredentialsFactory.Setup(i => i.GetWithSasToken("device1", "mod2", It.IsAny(), It.IsAny(), true)) - .Returns(clientCredentials2); - - var authenticator = new Mock(); - authenticator.Setup(a => a.AuthenticateAsync(clientCredentials1)).ReturnsAsync(true); - authenticator.Setup(a => a.AuthenticateAsync(clientCredentials2)).ReturnsAsync(true); - var cbsNode = new CbsNode(clientCredentialsFactory.Object, iotHubHostName, authenticator.Object, new NullCredentialsCache()); - - // Act - (AmqpResponseStatusCode statusCode1, string description1) = await cbsNode.UpdateCbsToken(validAmqpMessage1); - - // Assert - Assert.Equal(AmqpResponseStatusCode.OK, statusCode1); - Assert.Equal(AmqpResponseStatusCode.OK.ToString(), description1); - - // Act - (AmqpResponseStatusCode statusCode2, string description2) = await cbsNode.UpdateCbsToken(validAmqpMessage2); - - // Assert - Assert.Equal(AmqpResponseStatusCode.OK, statusCode2); - Assert.Equal(AmqpResponseStatusCode.OK.ToString(), description2); - - // Act - bool isAuthenticated = await cbsNode.AuthenticateAsync("device1/mod1"); - - // Assert - Assert.True(isAuthenticated); - authenticator.Verify(a => a.AuthenticateAsync(clientCredentials1), Times.Once); - - // Act - isAuthenticated = await cbsNode.AuthenticateAsync("device1/mod1"); - - // Assert - Assert.True(isAuthenticated); - authenticator.Verify(a => a.AuthenticateAsync(clientCredentials1), Times.Once); - - // Act - isAuthenticated = await cbsNode.AuthenticateAsync("device1/mod2"); - - // Assert - Assert.True(isAuthenticated); - authenticator.Verify(a => a.AuthenticateAsync(clientCredentials2), Times.Once); - - // Act - isAuthenticated = await cbsNode.AuthenticateAsync("device1/mod2"); - - // Assert - Assert.True(isAuthenticated); - authenticator.Verify(a => a.AuthenticateAsync(clientCredentials2), Times.Once); - authenticator.Verify(a => a.AuthenticateAsync(clientCredentials1), Times.Once); - - // Act - isAuthenticated = await cbsNode.AuthenticateAsync("device1/mod3"); - - // Assert - Assert.False(isAuthenticated); - } } } diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/ConnectionHandlerTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/ConnectionHandlerTest.cs index beaa67bfd1e..77dcdf320a2 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/ConnectionHandlerTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/ConnectionHandlerTest.cs @@ -9,7 +9,6 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; - using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Azure.Devices.Edge.Util.Test.Common; using Moq; using Xunit; @@ -21,13 +20,13 @@ public class ConnectionHandlerTest public void ConnectionHandlerCtorTest() { // Arrange - var amqpConnection = Mock.Of(); + var identity = Mock.Of(); var connectionPovider = Mock.Of(); // Act / Assert - Assert.NotNull(new ConnectionHandler(amqpConnection, connectionPovider)); - Assert.Throws(() => new ConnectionHandler(null, connectionPovider)); - Assert.Throws(() => new ConnectionHandler(amqpConnection, null)); + Assert.NotNull(new ClientConnectionHandler(identity, connectionPovider)); + Assert.Throws(() => new ClientConnectionHandler(null, connectionPovider)); + Assert.Throws(() => new ClientConnectionHandler(identity, null)); } [Fact] @@ -36,17 +35,12 @@ public async Task GetDeviceListenerTest() // Arrange IDeviceProxy deviceProxy = null; var identity = Mock.Of(i => i.Id == "d1/m1"); - var clientCredentials = Mock.Of(c => c.Identity == identity); var deviceListener = Mock.Of(); Mock.Get(deviceListener).Setup(d => d.BindDeviceProxy(It.IsAny())) .Callback(d => deviceProxy = d); - var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(clientCredentials.Identity) == Task.FromResult(deviceListener)); - - var amqpAuthentication = new AmqpAuthentication(true, Option.Some(clientCredentials)); - var cbsNode = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuthentication)); - var amqpConnection = Mock.Of(c => c.FindExtension() == cbsNode); - var connectionHandler = new ConnectionHandler(amqpConnection, connectionProvider); + var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(identity) == Task.FromResult(deviceListener)); + var connectionHandler = new ClientConnectionHandler(identity, connectionProvider); // Act var tasks = new List>(); @@ -64,64 +58,23 @@ public async Task GetDeviceListenerTest() Assert.Equal(deviceListener, deviceListeners[0]); } Assert.NotNull(deviceProxy); - Mock.Get(connectionProvider).Verify(c => c.GetDeviceListenerAsync(It.IsAny()), Times.AtMostOnce); - Mock.Get(deviceListener).Verify(d => d.BindDeviceProxy(It.IsAny()), Times.AtMostOnce); - } - - [Fact] - public async Task GetAmqpAuthenticationTest() - { - // Arrange - var identity = Mock.Of(i => i.Id == "d1/m1"); - var clientCredentials = Mock.Of(c => c.Identity == identity); - var deviceListener = Mock.Of(); - Mock.Get(deviceListener).Setup(d => d.BindDeviceProxy(It.IsAny())); - - var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(clientCredentials.Identity) == Task.FromResult(deviceListener)); - - var amqpAuthentication = new AmqpAuthentication(true, Option.Some(clientCredentials)); - var cbsNode = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuthentication)); - var amqpConnection = Mock.Of(c => c.FindExtension() == cbsNode); - var connectionHandler = new ConnectionHandler(amqpConnection, connectionProvider); - - // Act - var tasks = new List>(); - for (int i = 0; i < 10; i++) - { - tasks.Add(connectionHandler.GetAmqpAuthentication()); - } - IList amqpAuthentications = (await Task.WhenAll(tasks)).ToList(); - - // Assert - Assert.NotNull(amqpAuthentications); - Assert.Equal(10, amqpAuthentications.Count); - for (int i = 0; i < 10; i++) - { - Assert.Equal(amqpAuthentication, amqpAuthentications[0]); - } - Assert.True(amqpAuthentications[0].IsAuthenticated); - Assert.Equal(identity, amqpAuthentications[0].ClientCredentials.OrDefault().Identity); - Mock.Get(connectionProvider).Verify(c => c.GetDeviceListenerAsync(It.IsAny()), Times.AtMostOnce); - Mock.Get(cbsNode).Verify(d => d.GetAmqpAuthentication(), Times.AtMostOnce); + Mock.Get(connectionProvider).Verify(c => c.GetDeviceListenerAsync(It.IsAny()), Times.Once); + Mock.Get(deviceListener).Verify(d => d.BindDeviceProxy(It.IsAny()), Times.Once); } - + [Fact] public async Task RegisterC2DMessageSenderTest() { // Arrange IDeviceProxy deviceProxy = null; var identity = Mock.Of(i => i.Id == "d1"); - var clientCredentials = Mock.Of(c => c.Identity == identity); var deviceListener = Mock.Of(); Mock.Get(deviceListener).Setup(d => d.BindDeviceProxy(It.IsAny())) .Callback(d => deviceProxy = d); - var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(clientCredentials.Identity) == Task.FromResult(deviceListener)); - - var amqpAuthentication = new AmqpAuthentication(true, Option.Some(clientCredentials)); - var cbsNode = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuthentication)); - var amqpConnection = Mock.Of(c => c.FindExtension() == cbsNode); - var connectionHandler = new ConnectionHandler(amqpConnection, connectionProvider); + var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(identity) == Task.FromResult(deviceListener)); + + var connectionHandler = new ClientConnectionHandler(identity, connectionProvider); IMessage receivedMessage = null; var c2DLinkHandler = new Mock(); @@ -151,17 +104,13 @@ public async Task RegisterModuleMessageSenderTest() // Arrange IDeviceProxy deviceProxy = null; var identity = Mock.Of(i => i.Id == "d1/m1"); - var clientCredentials = Mock.Of(c => c.Identity == identity); var deviceListener = Mock.Of(); Mock.Get(deviceListener).Setup(d => d.BindDeviceProxy(It.IsAny())) .Callback(d => deviceProxy = d); - var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(clientCredentials.Identity) == Task.FromResult(deviceListener)); + var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(identity) == Task.FromResult(deviceListener)); - var amqpAuthentication = new AmqpAuthentication(true, Option.Some(clientCredentials)); - var cbsNode = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuthentication)); - var amqpConnection = Mock.Of(c => c.FindExtension() == cbsNode); - var connectionHandler = new ConnectionHandler(amqpConnection, connectionProvider); + var connectionHandler = new ClientConnectionHandler(identity, connectionProvider); IMessage receivedMessage = null; var moduleMessageLinkHandler = new Mock(); @@ -191,17 +140,13 @@ public async Task RegisterMethodInvokerTest() // Arrange IDeviceProxy deviceProxy = null; var identity = Mock.Of(i => i.Id == "d1/m1"); - var clientCredentials = Mock.Of(c => c.Identity == identity); var deviceListener = Mock.Of(); Mock.Get(deviceListener).Setup(d => d.BindDeviceProxy(It.IsAny())) .Callback(d => deviceProxy = d); - var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(clientCredentials.Identity) == Task.FromResult(deviceListener)); + var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(identity) == Task.FromResult(deviceListener)); - var amqpAuthentication = new AmqpAuthentication(true, Option.Some(clientCredentials)); - var cbsNode = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuthentication)); - var amqpConnection = Mock.Of(c => c.FindExtension() == cbsNode); - var connectionHandler = new ConnectionHandler(amqpConnection, connectionProvider); + var connectionHandler = new ClientConnectionHandler(identity, connectionProvider); IMessage receivedMessage = null; var methodSendingLinkHandler = new Mock(); @@ -231,17 +176,13 @@ public async Task RegisterDesiredPropertiesUpdateSenderTest() // Arrange IDeviceProxy deviceProxy = null; var identity = Mock.Of(i => i.Id == "d1/m1"); - var clientCredentials = Mock.Of(c => c.Identity == identity); var deviceListener = Mock.Of(); Mock.Get(deviceListener).Setup(d => d.BindDeviceProxy(It.IsAny())) .Callback(d => deviceProxy = d); - var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(clientCredentials.Identity) == Task.FromResult(deviceListener)); + var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(identity) == Task.FromResult(deviceListener)); - var amqpAuthentication = new AmqpAuthentication(true, Option.Some(clientCredentials)); - var cbsNode = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuthentication)); - var amqpConnection = Mock.Of(c => c.FindExtension() == cbsNode); - var connectionHandler = new ConnectionHandler(amqpConnection, connectionProvider); + var connectionHandler = new ClientConnectionHandler(identity, connectionProvider); IMessage receivedMessage = null; var twinSendingLinkHandler = new Mock(); @@ -270,14 +211,10 @@ public async Task CloseOnRemovingAllLinksTest() var deviceListener = new Mock(); deviceListener.Setup(d => d.CloseAsync()).Returns(Task.CompletedTask); var identity = Mock.Of(i => i.Id == "d1/m1"); - var clientCredentials = Mock.Of(c => c.Identity == identity); - var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(clientCredentials.Identity) == Task.FromResult(deviceListener.Object)); + var connectionProvider = Mock.Of(c => c.GetDeviceListenerAsync(identity) == Task.FromResult(deviceListener.Object)); deviceListener.Setup(d => d.BindDeviceProxy(It.IsAny())); - var amqpAuthentication = new AmqpAuthentication(true, Option.Some(clientCredentials)); - var cbsNode = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuthentication)); - var amqpConnection = Mock.Of(c => c.FindExtension() == cbsNode); - var connectionHandler = new ConnectionHandler(amqpConnection, connectionProvider); + var connectionHandler = new ClientConnectionHandler(identity, connectionProvider); var eventsLinkHandler = Mock.Of(l => l.Type == LinkType.Events); string twinCorrelationId = Guid.NewGuid().ToString(); diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/DeviceBoundLinkHandlerTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/DeviceBoundLinkHandlerTest.cs index 31c2927ebd1..7e5047b258a 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/DeviceBoundLinkHandlerTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/DeviceBoundLinkHandlerTest.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test { using System; @@ -11,7 +12,6 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; - using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Azure.Devices.Edge.Util.Test.Common; using Moq; using Xunit; @@ -31,9 +31,10 @@ public void CreateTest() var requestUri = new Uri("amqps://foo.bar//devices/d1/messages/deviceBound"); var boundVariables = new Dictionary { { "deviceid", "d1" } }; var messageConverter = Mock.Of>(); + var identity = Mock.Of(d => d.Id == "d1"); // Act - ILinkHandler linkHandler = new DeviceBoundLinkHandler(amqpLink, requestUri, boundVariables, messageConverter); + ILinkHandler linkHandler = new DeviceBoundLinkHandler(identity, amqpLink, requestUri, boundVariables, connectionHandler, messageConverter); // Assert Assert.NotNull(linkHandler); @@ -54,9 +55,10 @@ public void CreateThrowsExceptionIfReceiverLinkTest() var requestUri = new Uri("amqps://foo.bar//devices/d1/messages/deviceBound"); var boundVariables = new Dictionary { { "deviceid", "d1" } }; var messageConverter = Mock.Of>(); + var identity = Mock.Of(d => d.Id == "d1"); // Act / Assert - Assert.Throws(() => new DeviceBoundLinkHandler(amqpLink, requestUri, boundVariables, messageConverter)); + Assert.Throws(() => new DeviceBoundLinkHandler(identity, amqpLink, requestUri, boundVariables, connectionHandler, messageConverter)); } [Fact] @@ -69,9 +71,15 @@ public async Task SendMessageTest() .Callback((m, s) => feedbackStatus = s) .Returns(Task.CompletedTask); AmqpMessage receivedAmqpMessage = null; - var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object) - && c.GetAmqpAuthentication() == Task.FromResult(new AmqpAuthentication(true, Option.Some(Mock.Of())))); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of( + c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var sendingLink = Mock.Of(l => l.Session == amqpSession && !l.IsReceiver && l.Settings == new AmqpLinkSettings() && l.State == AmqpObjectState.Opened); Mock.Get(sendingLink).Setup(s => s.SendMessageNoWait(It.IsAny(), It.IsAny>(), It.IsAny>())) @@ -80,13 +88,15 @@ public async Task SendMessageTest() var requestUri = new Uri("amqps://foo.bar/devices/d1"); var boundVariables = new Dictionary { { "deviceid", "d1" } }; var messageConverter = new AmqpMessageConverter(); + var identity = Mock.Of(d => d.Id == "d1"); - var sendingLinkHandler = new DeviceBoundLinkHandler(sendingLink, requestUri, boundVariables, messageConverter); + var sendingLinkHandler = new DeviceBoundLinkHandler(identity, sendingLink, requestUri, boundVariables, connectionHandler, messageConverter); var body = new byte[] { 0, 1, 2, 3 }; IMessage message = new EdgeMessage.Builder(body).Build(); var deliveryState = new Mock(new AmqpSymbol(""), AmqpConstants.AcceptedOutcome.DescriptorCode); - var delivery = Mock.Of(d => d.State == deliveryState.Object - && d.DeliveryTag == new ArraySegment(Guid.NewGuid().ToByteArray())); + var delivery = Mock.Of( + d => d.State == deliveryState.Object + && d.DeliveryTag == new ArraySegment(Guid.NewGuid().ToByteArray())); // Act await sendingLinkHandler.OpenAsync(TimeSpan.FromSeconds(5)); diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/EdgeHubSaslPlainAuthenticatorTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/EdgeHubSaslPlainAuthenticatorTest.cs index 71ee7709b99..ae697a11339 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/EdgeHubSaslPlainAuthenticatorTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/EdgeHubSaslPlainAuthenticatorTest.cs @@ -3,6 +3,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test { using System; + using System.Security.Principal; using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Util.Test.Common; @@ -100,11 +101,14 @@ public async void TestAuthSucceeds() Mock.Get(authenticator).Setup(a => a.AuthenticateAsync(clientCredentials)) .ReturnsAsync(true); - var principal = await saslAuthenticator.AuthenticateAsync(UserId, Password) as SaslPrincipal; + IPrincipal principal = await saslAuthenticator.AuthenticateAsync(UserId, Password); Assert.NotNull(principal); - Assert.NotNull(principal.Identity); - Assert.NotNull(principal.AmqpAuthentication); - Assert.Equal(identity, principal.AmqpAuthentication.ClientCredentials.OrDefault().Identity); + + var amqpAuthenticator = principal as IAmqpAuthenticator; + Assert.NotNull(amqpAuthenticator); + + bool isAuthenticated = await amqpAuthenticator.AuthenticateAsync("dev1/mod1"); + Assert.True(isAuthenticated); } } } diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/EventsLinkHandlerTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/EventsLinkHandlerTest.cs index c040b6e46ce..cfa4babaa81 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/EventsLinkHandlerTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/EventsLinkHandlerTest.cs @@ -33,9 +33,10 @@ public void CreateTest() var requestUri = new Uri("amqps://foo.bar/devices/d1/messages/events"); var boundVariables = new Dictionary { { "deviceid", "d1" } }; var messageConverter = Mock.Of>(); + var identity = Mock.Of(d => d.Id == "d1"); // Act - ILinkHandler linkHandler = new EventsLinkHandler(amqpLink, requestUri, boundVariables, messageConverter); + ILinkHandler linkHandler = new EventsLinkHandler(identity, amqpLink, requestUri, boundVariables, connectionHandler.Object, messageConverter); // Assert Assert.NotNull(linkHandler); @@ -48,15 +49,20 @@ public void CreateTest() public async Task SendMessageTest() { // Arrange - var identity = Mock.Of(i => i.Id == "d1"); - var amqpAuth = new AmqpAuthentication(true, Option.Some(Mock.Of(c => c.Identity == identity))); + var identity = Mock.Of(i => i.Id == "d1" && i.DeviceId == "d1"); IEnumerable receivedMessages = null; var deviceListener = Mock.Of(); Mock.Get(deviceListener).Setup(d => d.ProcessDeviceMessageBatchAsync(It.IsAny>())).Callback>(m => receivedMessages = m); - var connectionHandler = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuth) && c.GetDeviceListener() == Task.FromResult(deviceListener)); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of(c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var amqpLink = Mock.Of(l => l.Session == amqpSession && l.IsReceiver && l.Settings == new AmqpLinkSettings() && l.State == AmqpObjectState.Opened); @@ -76,7 +82,7 @@ public async Task SendMessageTest() amqpMessage.Properties.ContentType = "application/json"; amqpMessage.Properties.ContentEncoding = "utf-8"; - ILinkHandler linkHandler = new EventsLinkHandler(amqpLink, requestUri, boundVariables, messageConverter); + ILinkHandler linkHandler = new EventsLinkHandler(identity, amqpLink, requestUri, boundVariables, connectionHandler, messageConverter); // Act await linkHandler.OpenAsync(TimeSpan.FromSeconds(30)); @@ -114,7 +120,6 @@ public async Task SendMessageBatchTest() { // Arrange var identity = Mock.Of(i => i.Id == "d1"); - var amqpAuth = new AmqpAuthentication(true, Option.Some(Mock.Of(c => c.Identity == identity))); IEnumerable receivedMessages = null; var deviceListener = Mock.Of(); @@ -122,8 +127,14 @@ public async Task SendMessageBatchTest() .Callback>(m => receivedMessages = m) .Returns(Task.CompletedTask); - var connectionHandler = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuth) && c.GetDeviceListener() == Task.FromResult(deviceListener)); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of(c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var amqpLink = Mock.Of(l => l.Session == amqpSession && l.IsReceiver && l.Settings == new AmqpLinkSettings() && l.State == AmqpObjectState.Opened); @@ -151,7 +162,7 @@ public async Task SendMessageBatchTest() using (AmqpMessage amqpMessage = GetBatchedMessage(contents)) { amqpMessage.MessageFormat = AmqpConstants.AmqpBatchedMessageFormat; - ILinkHandler linkHandler = new EventsLinkHandler(amqpLink, requestUri, boundVariables, messageConverter); + ILinkHandler linkHandler = new EventsLinkHandler(identity, amqpLink, requestUri, boundVariables, connectionHandler, messageConverter); // Act await linkHandler.OpenAsync(TimeSpan.FromSeconds(30)); @@ -195,14 +206,19 @@ public async Task SendLargeMessageThrowsTest() // Arrange bool disposeMessageCalled = true; var identity = Mock.Of(i => i.Id == "d1"); - var amqpAuth = new AmqpAuthentication(true, Option.Some(Mock.Of(c => c.Identity == identity))); - + var deviceListener = Mock.Of(); Mock.Get(deviceListener).Setup(d => d.ProcessDeviceMessageBatchAsync(It.IsAny>())) .Returns(Task.CompletedTask); - var connectionHandler = Mock.Of(c => c.GetAmqpAuthentication() == Task.FromResult(amqpAuth) && c.GetDeviceListener() == Task.FromResult(deviceListener)); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of(c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var amqpLink = Mock.Of(l => l.Session == amqpSession && l.IsReceiver && l.Settings == new AmqpLinkSettings() && l.State == AmqpObjectState.Opened); @@ -220,7 +236,7 @@ public async Task SendLargeMessageThrowsTest() using (AmqpMessage amqpMessage = AmqpMessage.Create(new MemoryStream(new byte[800000]), false)) { amqpMessage.ApplicationProperties.Map["LargeProp"] = new int[600000]; - ILinkHandler linkHandler = new EventsLinkHandler(amqpLink, requestUri, boundVariables, messageConverter); + ILinkHandler linkHandler = new EventsLinkHandler(identity, amqpLink, requestUri, boundVariables, connectionHandler, messageConverter); // Act await linkHandler.OpenAsync(TimeSpan.FromSeconds(30)); diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/LinkHandlerProviderTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/LinkHandlerProviderTest.cs index 1b6a9d6418e..f4fd184de06 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/LinkHandlerProviderTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/LinkHandlerProviderTest.cs @@ -1,4 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test { using System; @@ -6,6 +7,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test using Microsoft.Azure.Amqp; using Microsoft.Azure.Devices.Edge.Hub.Amqp.LinkHandlers; using Microsoft.Azure.Devices.Edge.Hub.Core; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; using Microsoft.Azure.Devices.Edge.Util.Test.Common; using Moq; using Xunit; @@ -18,17 +20,59 @@ static IEnumerable GetLinkTypeTestData() yield return new object[] { "amqps://foo.bar/$cbs", true, LinkType.Cbs, new Dictionary() }; yield return new object[] { "amqps://foo.bar/$cbs", false, LinkType.Cbs, new Dictionary() }; yield return new object[] { "amqps://foo.bar//devices/device1/messages/events", true, LinkType.Events, new Dictionary { { "deviceid", "device1" } } }; - yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/messages/events", true, LinkType.Events, new Dictionary { { "deviceid", "device1" }, { "moduleid", "module1" } } }; - yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/messages/events", false, LinkType.ModuleMessages, new Dictionary { { "deviceid", "device1" }, { "moduleid", "module1" } } }; + yield return new object[] + { + "amqps://foo.bar/devices/device1/modules/module1/messages/events", true, LinkType.Events, new Dictionary + { + { "deviceid", "device1" }, + { "moduleid", "module1" } + } + }; + yield return new object[] + { + "amqps://foo.bar/devices/device1/modules/module1/messages/events", false, LinkType.ModuleMessages, new Dictionary + { + { "deviceid", "device1" }, + { "moduleid", "module1" } + } + }; yield return new object[] { "amqps://foo.bar/devices/device1/messages/deviceBound", false, LinkType.C2D, new Dictionary { { "deviceid", "device1" } } }; yield return new object[] { "amqps://foo.bar/devices/device1/methods/deviceBound", false, LinkType.MethodSending, new Dictionary { { "deviceid", "device1" } } }; - yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/methods/deviceBound", false, LinkType.MethodSending, new Dictionary { { "deviceid", "device1" }, { "moduleid", "module1" } } }; + yield return new object[] + { + "amqps://foo.bar/devices/device1/modules/module1/methods/deviceBound", false, LinkType.MethodSending, new Dictionary + { + { "deviceid", "device1" }, + { "moduleid", "module1" } + } + }; yield return new object[] { "amqps://foo.bar/devices/device1/methods/deviceBound", true, LinkType.MethodReceiving, new Dictionary { { "deviceid", "device1" } } }; - yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/methods/deviceBound", true, LinkType.MethodReceiving, new Dictionary { { "deviceid", "device1" }, { "moduleid", "module1" } } }; + yield return new object[] + { + "amqps://foo.bar/devices/device1/modules/module1/methods/deviceBound", true, LinkType.MethodReceiving, new Dictionary + { + { "deviceid", "device1" }, + { "moduleid", "module1" } + } + }; yield return new object[] { "amqps://foo.bar/devices/device1/twin", false, LinkType.TwinSending, new Dictionary { { "deviceid", "device1" } } }; - yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/twin", false, LinkType.TwinSending, new Dictionary { { "deviceid", "device1" }, { "moduleid", "module1" } } }; + yield return new object[] + { + "amqps://foo.bar/devices/device1/modules/module1/twin", false, LinkType.TwinSending, new Dictionary + { + { "deviceid", "device1" }, + { "moduleid", "module1" } + } + }; yield return new object[] { "amqps://foo.bar/devices/device1/twin", true, LinkType.TwinReceiving, new Dictionary { { "deviceid", "device1" } } }; - yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/twin", true, LinkType.TwinReceiving, new Dictionary { { "deviceid", "device1" }, { "moduleid", "module1" } } }; + yield return new object[] + { + "amqps://foo.bar/devices/device1/modules/module1/twin", true, LinkType.TwinReceiving, new Dictionary + { + { "deviceid", "device1" }, + { "moduleid", "module1" } + } + }; } [Theory] @@ -39,7 +83,8 @@ public void GetLinkTypeTest(string linkUri, bool isReceiver, LinkType expectedLi var messageConverter = Mock.Of>(); var twinMessageConverter = Mock.Of>(); var methodMessageConverter = Mock.Of>(); - var linkHandlerProvider = new LinkHandlerProvider(messageConverter, twinMessageConverter, methodMessageConverter); + var identityProvider = new IdentityProvider("foo.bar"); + var linkHandlerProvider = new LinkHandlerProvider(messageConverter, twinMessageConverter, methodMessageConverter, identityProvider); var amqpLink = Mock.Of(l => l.IsReceiver == isReceiver); var uri = new Uri(linkUri); @@ -72,7 +117,8 @@ public void GetInvalidLinkTypeTest(string linkUri, bool isReceiver) var messageConverter = Mock.Of>(); var twinMessageConverter = Mock.Of>(); var methodMessageConverter = Mock.Of>(); - var linkHandlerProvider = new LinkHandlerProvider(messageConverter, twinMessageConverter, methodMessageConverter); + var identityProvider = new IdentityProvider("foo.bar"); + var linkHandlerProvider = new LinkHandlerProvider(messageConverter, twinMessageConverter, methodMessageConverter, identityProvider); var amqpLink = Mock.Of(l => l.IsReceiver == isReceiver); var uri = new Uri(linkUri); @@ -83,40 +129,47 @@ public void GetInvalidLinkTypeTest(string linkUri, bool isReceiver) static IEnumerable GetLinkHandlerTestData() { - yield return new object[] { LinkType.Cbs, true, typeof(CbsLinkHandler) }; - yield return new object[] { LinkType.Cbs, false, typeof(CbsLinkHandler) }; - yield return new object[] { LinkType.Events, true, typeof(EventsLinkHandler) }; - yield return new object[] { LinkType.ModuleMessages, false, typeof(ModuleMessageLinkHandler) }; - yield return new object[] { LinkType.C2D, false, typeof(DeviceBoundLinkHandler) }; - yield return new object[] { LinkType.MethodReceiving, true, typeof(MethodReceivingLinkHandler) }; - yield return new object[] { LinkType.MethodSending, false, typeof(MethodSendingLinkHandler) }; - yield return new object[] { LinkType.TwinReceiving, true, typeof(TwinReceivingLinkHandler) }; - yield return new object[] { LinkType.TwinSending, false, typeof(TwinSendingLinkHandler) }; + yield return new object[] { "amqps://foo.bar/$cbs", true, typeof(CbsLinkHandler) }; + yield return new object[] { "amqps://foo.bar/$cbs", false, typeof(CbsLinkHandler) }; + yield return new object[] { "amqps://foo.bar//devices/device1/messages/events", true, typeof(EventsLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/messages/events", true, typeof(EventsLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/messages/events", false, typeof(ModuleMessageLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/messages/deviceBound", false, typeof(DeviceBoundLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/methods/deviceBound", false, typeof(MethodSendingLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/methods/deviceBound", false, typeof(MethodSendingLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/methods/deviceBound", true, typeof(MethodReceivingLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/methods/deviceBound", true, typeof(MethodReceivingLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/twin", false, typeof(TwinSendingLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/twin", false, typeof(TwinSendingLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/twin", true, typeof(TwinReceivingLinkHandler) }; + yield return new object[] { "amqps://foo.bar/devices/device1/modules/module1/twin", true, typeof(TwinReceivingLinkHandler) }; } [Theory] [MemberData(nameof(GetLinkHandlerTestData))] - public void GetLinkHandlerTest(LinkType linkType, bool isReceiver, Type expectedLinkHandlerType) + public void GetLinkHandlerTest(string url, bool isReceiver, Type expectedLinkHandlerType) { // Arrange var messageConverter = Mock.Of>(); var twinMessageConverter = Mock.Of>(); var methodMessageConverter = Mock.Of>(); - var linkHandlerProvider = new LinkHandlerProvider(messageConverter, twinMessageConverter, methodMessageConverter); + var identityProvider = new IdentityProvider("foo.bar"); + var linkHandlerProvider = new LinkHandlerProvider(messageConverter, twinMessageConverter, methodMessageConverter, identityProvider); - var uri = new Uri("amqps://foo.bar//abs/prq"); - var amqpConnection = Mock.Of(c => c.FindExtension() == Mock.Of()); + var uri = new Uri(url); + var amqpClientConnectionsHandler = Mock.Of(c => c.GetConnectionHandler(It.IsAny()) == Mock.Of()); + var amqpConnection = Mock.Of(c => c.FindExtension() == amqpClientConnectionsHandler); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); IAmqpLink amqpLink = isReceiver ? Mock.Of(l => l.IsReceiver && l.Session == amqpSession) : Mock.Of(l => !l.IsReceiver && l.Session == amqpSession) as IAmqpLink; - if (linkType == LinkType.Cbs) + if (url.Contains("$cbs")) { Mock.Get(amqpConnection).Setup(c => c.FindExtension()).Returns(Mock.Of()); } // Act - ILinkHandler linkHandler = linkHandlerProvider.GetLinkHandler(linkType, amqpLink, uri, new Dictionary()); + ILinkHandler linkHandler = linkHandlerProvider.Create(amqpLink, uri); // Assert Assert.NotNull(linkHandler); diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/ReceivingLinkHandlerTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/ReceivingLinkHandlerTest.cs index b91dad3e41b..72f5dd4a69b 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/ReceivingLinkHandlerTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/ReceivingLinkHandlerTest.cs @@ -11,7 +11,6 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; - using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Azure.Devices.Edge.Util.Test.Common; using Moq; using Xunit; @@ -24,9 +23,15 @@ public async Task ReceiveMessageTest() { // Arrange var deviceListener = new Mock(); - var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object) - && c.GetAmqpAuthentication() == Task.FromResult(new AmqpAuthentication(true, Option.Some(Mock.Of())))); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of( + c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var receivingLink = Mock.Of(l => l.Session == amqpSession && l.IsReceiver && l.Settings == new AmqpLinkSettings() && l.State == AmqpObjectState.Opened); @@ -35,9 +40,10 @@ public async Task ReceiveMessageTest() var messageConverter = new AmqpMessageConverter(); var body = new byte[] { 0, 1, 2, 3 }; AmqpMessage message = AmqpMessage.Create(new Data { Value = new ArraySegment(body) }); + var identity = Mock.Of(i => i.Id == "d1"); // Act - var receivingLinkHandler = new TestReceivingLinkHandler(receivingLink, requestUri, boundVariables, messageConverter); + var receivingLinkHandler = new TestReceivingLinkHandler(identity, receivingLink, requestUri, boundVariables, connectionHandler, messageConverter); await receivingLinkHandler.OpenAsync(Amqp.Constants.DefaultTimeout); await receivingLinkHandler.ProcessMessageAsync(message); @@ -50,8 +56,14 @@ public async Task ReceiveMessageTest() class TestReceivingLinkHandler : ReceivingLinkHandler { - public TestReceivingLinkHandler(IReceivingAmqpLink link, Uri requestUri, IDictionary boundVariables, IMessageConverter messageConverter) - : base(link, requestUri, boundVariables, messageConverter) + public TestReceivingLinkHandler( + IIdentity identity, + IReceivingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, + IMessageConverter messageConverter) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { } diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/SaslPrincipalTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/SaslPrincipalTest.cs index 7299b381f58..6dc91d5ffd7 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/SaslPrincipalTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/SaslPrincipalTest.cs @@ -4,7 +4,6 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test { using System; using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; - using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Azure.Devices.Edge.Util.Test.Common; using Moq; using Xunit; @@ -16,8 +15,8 @@ public class SaslPrincipalTest public void TestNullConstructorInputs() { var edgeHubIdentity = Mock.Of(i => i.Identity == Mock.Of(id => id.Id == "dev1/mod1")); - Assert.Throws(() => new SaslPrincipal(null)); - Assert.NotNull(new SaslPrincipal(new AmqpAuthentication(true, Option.Some(edgeHubIdentity)))); + Assert.Throws(() => new SaslPrincipal(false, null)); + Assert.NotNull(new SaslPrincipal(true, edgeHubIdentity)); } [Fact] @@ -25,7 +24,7 @@ public void TestNullConstructorInputs() public void TestIsInRoleThrows() { var edgeHubIdentity = Mock.Of(i => i.Identity == Mock.Of(id => id.Id == "dev1/mod1")); - var principal = new SaslPrincipal(new AmqpAuthentication(true, Option.Some(edgeHubIdentity))); + var principal = new SaslPrincipal(true, edgeHubIdentity); Assert.Throws(() => principal.IsInRole("boo")); } } diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/SendingLinkHandlerTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/SendingLinkHandlerTest.cs index 547d487ee4c..e1ce08983fb 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/SendingLinkHandlerTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/SendingLinkHandlerTest.cs @@ -12,7 +12,6 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; - using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Azure.Devices.Edge.Util.Test.Common; using Moq; using Xunit; @@ -30,9 +29,15 @@ public async Task SendMessageWithFeedbackTest() .Callback((m, s) => feedbackStatus = s) .Returns(Task.CompletedTask); AmqpMessage receivedAmqpMessage = null; - var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object) - && c.GetAmqpAuthentication() == Task.FromResult(new AmqpAuthentication(true, Option.Some(Mock.Of())))); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var identity = Mock.Of(i => i.Id == "d1"); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of(c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var amqpLinkSettings = new AmqpLinkSettings(); var sendingLink = Mock.Of(l => l.Session == amqpSession && !l.IsReceiver && l.Settings == amqpLinkSettings && l.State == AmqpObjectState.Opened); @@ -43,12 +48,13 @@ public async Task SendMessageWithFeedbackTest() var boundVariables = new Dictionary { { "deviceid", "d1" } }; var messageConverter = new AmqpMessageConverter(); - var sendingLinkHandler = new TestSendingLinkHandler(sendingLink, requestUri, boundVariables, messageConverter, QualityOfService.AtLeastOnce); + var sendingLinkHandler = new TestSendingLinkHandler(identity, sendingLink, requestUri, boundVariables, connectionHandler, messageConverter, QualityOfService.AtLeastOnce); var body = new byte[] { 0, 1, 2, 3 }; IMessage message = new EdgeMessage.Builder(body).Build(); var deliveryState = new Mock(new AmqpSymbol(""), AmqpConstants.AcceptedOutcome.DescriptorCode); - var delivery = Mock.Of(d => d.State == deliveryState.Object - && d.DeliveryTag == new ArraySegment(Guid.NewGuid().ToByteArray())); + var delivery = Mock.Of( + d => d.State == deliveryState.Object + && d.DeliveryTag == new ArraySegment(Guid.NewGuid().ToByteArray())); // Act await sendingLinkHandler.OpenAsync(TimeSpan.FromSeconds(5)); @@ -78,9 +84,15 @@ public async Task SendMessageWithFeedbackExactlyOnceModeTest() .Callback((m, s) => feedbackStatus = s) .Returns(Task.CompletedTask); AmqpMessage receivedAmqpMessage = null; - var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object) - && c.GetAmqpAuthentication() == Task.FromResult(new AmqpAuthentication(true, Option.Some(Mock.Of())))); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var identity = Mock.Of(i => i.Id == "d1"); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of(c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var amqpLinkSettings = new AmqpLinkSettings(); var sendingLink = Mock.Of(l => l.Session == amqpSession && !l.IsReceiver && l.Settings == amqpLinkSettings && l.State == AmqpObjectState.Opened); @@ -91,12 +103,13 @@ public async Task SendMessageWithFeedbackExactlyOnceModeTest() var boundVariables = new Dictionary { { "deviceid", "d1" } }; var messageConverter = new AmqpMessageConverter(); - var sendingLinkHandler = new TestSendingLinkHandler(sendingLink, requestUri, boundVariables, messageConverter, QualityOfService.ExactlyOnce); + var sendingLinkHandler = new TestSendingLinkHandler(identity, sendingLink, requestUri, boundVariables, connectionHandler, messageConverter, QualityOfService.ExactlyOnce); var body = new byte[] { 0, 1, 2, 3 }; IMessage message = new EdgeMessage.Builder(body).Build(); var deliveryState = new Mock(new AmqpSymbol(""), AmqpConstants.AcceptedOutcome.DescriptorCode); - var delivery = Mock.Of(d => d.State == deliveryState.Object - && d.DeliveryTag == new ArraySegment(Guid.NewGuid().ToByteArray())); + var delivery = Mock.Of( + d => d.State == deliveryState.Object + && d.DeliveryTag == new ArraySegment(Guid.NewGuid().ToByteArray())); // Act await sendingLinkHandler.OpenAsync(TimeSpan.FromSeconds(5)); @@ -120,13 +133,19 @@ public async Task SendMessageWithFeedbackExactlyOnceModeTest() public async Task SendMessageWithNoFeedbackTest() { // Arrange + var identity = Mock.Of(i => i.Id == "d1"); var deviceListener = new Mock(); deviceListener.Setup(d => d.ProcessMessageFeedbackAsync(It.IsAny(), It.IsAny())) .Returns(Task.CompletedTask); AmqpMessage receivedAmqpMessage = null; - var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object) - && c.GetAmqpAuthentication() == Task.FromResult(new AmqpAuthentication(true, Option.Some(Mock.Of())))); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of(c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var amqpLinkSettings = new AmqpLinkSettings(); var sendingLink = Mock.Of(l => l.Session == amqpSession && !l.IsReceiver && l.Settings == amqpLinkSettings && l.State == AmqpObjectState.Opened); @@ -138,7 +157,7 @@ public async Task SendMessageWithNoFeedbackTest() var boundVariables = new Dictionary { { "deviceid", "d1" } }; var messageConverter = new AmqpMessageConverter(); - var sendingLinkHandler = new TestSendingLinkHandler(sendingLink, requestUri, boundVariables, messageConverter, QualityOfService.AtMostOnce); + var sendingLinkHandler = new TestSendingLinkHandler(identity, sendingLink, requestUri, boundVariables, connectionHandler, messageConverter, QualityOfService.AtMostOnce); var body = new byte[] { 0, 1, 2, 3 }; IMessage message = new EdgeMessage.Builder(body).Build(); @@ -178,9 +197,15 @@ public static IEnumerable FeedbackStatusTestData() class TestSendingLinkHandler : SendingLinkHandler { - public TestSendingLinkHandler(ISendingAmqpLink link, Uri requestUri, - IDictionary boundVariables, IMessageConverter messageConverter, QualityOfService qualityOfService) - : base(link, requestUri, boundVariables, messageConverter) + public TestSendingLinkHandler( + IIdentity identity, + ISendingAmqpLink link, + Uri requestUri, + IDictionary boundVariables, + IConnectionHandler connectionHandler, + IMessageConverter messageConverter, + QualityOfService qualityOfService) + : base(identity, link, requestUri, boundVariables, connectionHandler, messageConverter) { this.QualityOfService = qualityOfService; } diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/TwinReceivingLinkHandlerTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/TwinReceivingLinkHandlerTest.cs index c8cf9ec1291..b0b841f3c85 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/TwinReceivingLinkHandlerTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.Amqp.Test/TwinReceivingLinkHandlerTest.cs @@ -10,7 +10,6 @@ namespace Microsoft.Azure.Devices.Edge.Hub.Amqp.Test using Microsoft.Azure.Devices.Edge.Hub.Core; using Microsoft.Azure.Devices.Edge.Hub.Core.Device; using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; - using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Azure.Devices.Edge.Util.Test.Common; using Moq; using Xunit; @@ -27,16 +26,23 @@ public async Task ProcessPutOperationMessageTest() deviceListener.Setup(d => d.AddDesiredPropertyUpdatesSubscription(It.IsAny())) .Callback(c => receivedCorrelationId = c) .Returns(Task.CompletedTask); - var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object) - && c.GetAmqpAuthentication() == Task.FromResult(new AmqpAuthentication(true, Option.Some(Mock.Of())))); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of( + c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var receivingLink = Mock.Of(l => l.Session == amqpSession && l.IsReceiver && l.Settings == new AmqpLinkSettings() && l.State == AmqpObjectState.Opened); var requestUri = new Uri("amqps://foo.bar/devices/d1/twin"); var boundVariables = new Dictionary { { "deviceid", "d1" } }; var messageConverter = Mock.Of>(); - var twinReceivingLinkHandler = new TwinReceivingLinkHandler(receivingLink, requestUri, boundVariables, messageConverter); + var identity = Mock.Of(i => i.Id == "d1"); + var twinReceivingLinkHandler = new TwinReceivingLinkHandler(identity, receivingLink, requestUri, boundVariables, connectionHandler, messageConverter); string correlationId = Guid.NewGuid().ToString(); AmqpMessage amqpMessage = AmqpMessage.Create(); @@ -62,16 +68,23 @@ public async Task ProcessDeleteOperationMessageTest() deviceListener.Setup(d => d.RemoveDesiredPropertyUpdatesSubscription(It.IsAny())) .Callback(c => receivedCorrelationId = c) .Returns(Task.CompletedTask); - var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object) - && c.GetAmqpAuthentication() == Task.FromResult(new AmqpAuthentication(true, Option.Some(Mock.Of())))); - var amqpConnection = Mock.Of(c => c.FindExtension() == connectionHandler); + var connectionHandler = Mock.Of(c => c.GetDeviceListener() == Task.FromResult(deviceListener.Object)); + var amqpAuthenticator = new Mock(); + amqpAuthenticator.Setup(c => c.AuthenticateAsync("d1")).ReturnsAsync(true); + Mock cbsNodeMock = amqpAuthenticator.As(); + ICbsNode cbsNode = cbsNodeMock.Object; + var amqpConnection = Mock.Of( + c => + c.FindExtension() == connectionHandler && + c.FindExtension() == cbsNode); var amqpSession = Mock.Of(s => s.Connection == amqpConnection); var receivingLink = Mock.Of(l => l.Session == amqpSession && l.IsReceiver && l.Settings == new AmqpLinkSettings() && l.State == AmqpObjectState.Opened); var requestUri = new Uri("amqps://foo.bar/devices/d1/twin"); var boundVariables = new Dictionary { { "deviceid", "d1" } }; var messageConverter = Mock.Of>(); - var twinReceivingLinkHandler = new TwinReceivingLinkHandler(receivingLink, requestUri, boundVariables, messageConverter); + var identity = Mock.Of(i => i.Id == "d1"); + var twinReceivingLinkHandler = new TwinReceivingLinkHandler(identity, receivingLink, requestUri, boundVariables, connectionHandler, messageConverter); string correlationId = Guid.NewGuid().ToString(); AmqpMessage amqpMessage = AmqpMessage.Create();