Skip to content

Commit

Permalink
Drop messages when device is not in scope and auth mode is Scope (#4540)
Browse files Browse the repository at this point in the history
* Integrate src code
* Add scope validation in brokered cloud conn provider
  • Loading branch information
ancaantochi authored Mar 23, 2021
1 parent 90e4e87 commit 51ad827
Show file tree
Hide file tree
Showing 34 changed files with 1,577 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ namespace Microsoft.Azure.Devices.Edge.Hub.CloudProxy
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.Azure.Devices.Client;
using Microsoft.Azure.Devices.Client.Exceptions;
using Microsoft.Azure.Devices.Client.Transport.Mqtt;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Cloud;
Expand Down Expand Up @@ -34,6 +35,9 @@ public class CloudConnectionProvider : ICloudConnectionProvider
readonly TimeSpan operationTimeout;
readonly IMetadataStore metadataStore;
readonly bool nestedEdgeEnabled;
readonly bool scopeAuthenticationOnly;
readonly bool trackDeviceState;

Option<IEdgeHub> edgeHub;

public CloudConnectionProvider(
Expand All @@ -51,6 +55,8 @@ public CloudConnectionProvider(
bool useServerHeartbeat,
Option<IWebProxy> proxy,
IMetadataStore metadataStore,
bool scopeAuthenticationOnly,
bool trackDeviceState,
bool nestedEdgeEnabled = true)
{
this.messageConverterProvider = Preconditions.CheckNotNull(messageConverterProvider, nameof(messageConverterProvider));
Expand All @@ -69,6 +75,8 @@ public CloudConnectionProvider(
this.operationTimeout = operationTimeout;
this.metadataStore = Preconditions.CheckNotNull(metadataStore, nameof(metadataStore));
this.nestedEdgeEnabled = nestedEdgeEnabled;
this.scopeAuthenticationOnly = scopeAuthenticationOnly;
this.trackDeviceState = trackDeviceState;
}

public void BindEdgeHub(IEdgeHub edgeHubInstance)
Expand Down Expand Up @@ -153,7 +161,12 @@ public async Task<Try<ICloudConnection>> Connect(IClientCredentials clientCreden
}
}

public async Task<Try<ICloudConnection>> Connect(IIdentity identity, Action<string, CloudConnectionStatus> connectionStatusChangedHandler)
public Task<Try<ICloudConnection>> Connect(IIdentity identity, Action<string, CloudConnectionStatus> connectionStatusChangedHandler) =>
this.trackDeviceState
? this.ConnectInternalWithDeviceStateTracking(identity, connectionStatusChangedHandler, false)
: this.ConnectInternal(identity, connectionStatusChangedHandler);

async Task<Try<ICloudConnection>> ConnectInternal(IIdentity identity, Action<string, CloudConnectionStatus> connectionStatusChangedHandler)
{
Preconditions.CheckNotNull(identity, nameof(identity));

Expand Down Expand Up @@ -204,11 +217,18 @@ public async Task<Try<ICloudConnection>> Connect(IIdentity identity, Action<stri
.GetOrElse(
async () =>
{
Events.ServiceIdentityNotFound(identity);
Option<IClientCredentials> clientCredentials = await this.credentialsCache.Get(identity);
return await clientCredentials
.Map(cc => this.Connect(cc, connectionStatusChangedHandler))
.GetOrElse(() => throw new InvalidOperationException($"Unable to find identity {identity.Id} in device scopes cache or credentials cache"));
// allow to use credential cache when auth mode is not Scope only (could be CloudAndScope or Cloud) or identity is for edgeHub
if (!this.scopeAuthenticationOnly || this.edgeHubIdentity.Id.Equals(identity.Id))
{
Events.ServiceIdentityNotFound(identity);
Option<IClientCredentials> clientCredentials = await this.credentialsCache.Get(identity);
var clientCredential = clientCredentials.Expect(() => new InvalidOperationException($"Unable to find identity {identity.Id} in device scopes cache or credentials cache"));
return await this.Connect(clientCredential, connectionStatusChangedHandler);
}
else
{
throw new InvalidOperationException($"Unable to find identity {identity.Id} in device scopes cache");
}
});
}
catch (Exception ex)
Expand All @@ -218,6 +238,100 @@ public async Task<Try<ICloudConnection>> Connect(IIdentity identity, Action<stri
}
}

async Task<Try<ICloudConnection>> ConnectInternalWithDeviceStateTracking(IIdentity identity, Action<string, CloudConnectionStatus> connectionStatusChangedHandler, bool refreshCachedIdentity)
{
Preconditions.CheckNotNull(identity, nameof(identity));

try
{
var cloudListener = new CloudListener(this.edgeHub.Expect(() => new InvalidOperationException("EdgeHub reference should not be null")), identity.Id);
string authChain = await this.deviceScopeIdentitiesCache.VerifyServiceIdentityAuthChainState(identity.Id, this.nestedEdgeEnabled, refreshCachedIdentity);

return await this.TryCreateCloudConnectionFromServiceIdentity(identity, connectionStatusChangedHandler, refreshCachedIdentity, cloudListener, authChain);
}
catch (DeviceInvalidStateException ex)
{
return await this.TryRecoverCloudConnection(identity, connectionStatusChangedHandler, refreshCachedIdentity, ex);
}
catch (Exception ex)
{
Events.ErrorCreatingCloudConnection(identity, ex);
return Try<ICloudConnection>.Failure(ex);
}
}

async Task<Try<ICloudConnection>> TryCreateCloudConnectionFromServiceIdentity(IIdentity identity, Action<string, CloudConnectionStatus> connectionStatusChangedHandler, bool refreshOutOfDateCache, CloudListener cloudListener, string authChain)
{
Events.CreatingCloudConnectionOnBehalfOf(identity);
ConnectionMetadata connectionMetadata = await this.metadataStore.GetMetadata(identity.Id);
string productInfo = connectionMetadata.EdgeProductInfo;
Option<string> modelId = connectionMetadata.ModelId;

ITransportSettings[] transportSettings = GetTransportSettings(
this.upstreamProtocol,
this.connectionPoolSize,
this.proxy,
this.useServerHeartbeat,
authChain);

try
{
ICloudConnection cc = await CloudConnection.Create(
identity,
connectionStatusChangedHandler,
transportSettings,
this.messageConverterProvider,
this.clientProvider,
cloudListener,
this.edgeHubTokenProvider,
this.idleTimeout,
this.closeOnIdleTimeout,
this.operationTimeout,
productInfo,
modelId);
Events.SuccessCreatingCloudConnection(identity);
return Try.Success(cc);
}
catch (UnauthorizedException ex) when (this.trackDeviceState)
{
return await this.TryRecoverCloudConnection(identity, connectionStatusChangedHandler, refreshOutOfDateCache, ex);
}
}

async Task<Try<ICloudConnection>> TryRecoverCloudConnection(IIdentity identity, Action<string, CloudConnectionStatus> connectionStatusChangedHandler, bool wasRefreshed, Exception ex)
{
try
{
Events.ErrorCreatingCloudConnection(identity, ex);
if (this.scopeAuthenticationOnly && !this.edgeHubIdentity.Id.Equals(identity.Id))
{
if (wasRefreshed)
{
Events.ErrorCreatingCloudConnection(identity, ex);
return Try<ICloudConnection>.Failure(ex);
}
else
{
// recover: try to update out of date cache and try again
return await this.ConnectInternalWithDeviceStateTracking(identity, connectionStatusChangedHandler, true);
}
}
else
{
// try with cached device credentials if auth mode is not Scope or identity is for edgeHub
Events.ServiceIdentityNotFound(identity);
Option<IClientCredentials> clientCredentials = await this.credentialsCache.Get(identity);
var clientCredential = clientCredentials.Expect(() => new InvalidOperationException($"Unable to find identity {identity.Id} in device scopes cache or credentials cache"));
return await this.Connect(clientCredential, connectionStatusChangedHandler);
}
}
catch (Exception e)
{
Events.ErrorCreatingCloudConnection(identity, e);
return Try<ICloudConnection>.Failure(e);
}
}

static ITransportSettings[] GetAmqpTransportSettings(TransportType type, int connectionPoolSize, Option<IWebProxy> proxy, bool useServerHeartbeat, string authChain)
{
var settings = new AmqpTransportSettings(type)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Azure.Devices.Edge.Hub.CloudProxy
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
Expand Down Expand Up @@ -58,9 +59,10 @@ public async Task<Option<ServiceIdentity>> GetServiceIdentity(string deviceId, s
scopeResult = Option.Maybe(res);
Events.IdentityScopeResultReceived(deviceId);
}
catch (DeviceScopeApiException ex) when (ex.StatusCode == HttpStatusCode.BadRequest)
catch (DeviceScopeApiException ex)
{
Events.BadRequestResult(deviceId, ex.StatusCode);
Events.ErrorRequestResult(deviceId, ex.StatusCode);
throw this.MapException(ex);
}

Option<ServiceIdentity> serviceIdentityResult =
Expand Down Expand Up @@ -126,9 +128,10 @@ public async Task<Option<ServiceIdentity>> GetServiceIdentity(string deviceId, s
scopeResult = Option.Maybe(res);
Events.IdentityScopeResultReceived(id);
}
catch (DeviceScopeApiException ex) when (ex.StatusCode == HttpStatusCode.BadRequest)
catch (DeviceScopeApiException ex)
{
Events.BadRequestResult(id, ex.StatusCode);
Events.ErrorRequestResult(deviceId, ex.StatusCode);
throw this.MapException(ex);
}

Option<ServiceIdentity> serviceIdentityResult =
Expand Down Expand Up @@ -166,6 +169,21 @@ public async Task<Option<ServiceIdentity>> GetServiceIdentity(string deviceId, s
return serviceIdentityResult;
}

Exception MapException(DeviceScopeApiException ex)
{
switch (ex.StatusCode)
{
case HttpStatusCode.Unauthorized:
case HttpStatusCode.Forbidden:
return new DeviceInvalidStateException($"Device not in scope: [{ex.StatusCode}: {ex.Message}].", ex);
case HttpStatusCode.BadRequest:
case HttpStatusCode.NotFound:
return new DeviceInvalidStateException($"Device not found: [{ex.StatusCode}: {ex.Message}].", ex);
default:
return new TimeoutException($"Request failed: [{ex.StatusCode}: {ex.Message}].", ex);
}
}

static class Events
{
const int IdStart = CloudProxyEventIds.ServiceProxy;
Expand Down Expand Up @@ -215,9 +233,9 @@ public static void ScopeNotFound(string id)
Log.LogWarning((int)EventIds.NoScopeFound, $"Device scope not found for {id}. Parent-child relationship is not set.");
}

public static void BadRequestResult(string id, HttpStatusCode statusCode)
public static void ErrorRequestResult(string id, HttpStatusCode statusCode)
{
Log.LogDebug((int)EventIds.ScopeResultReceived, $"Received scope result for {id} with status code {statusCode} indicating that {id} has been removed from the scope");
Log.LogDebug((int)EventIds.ScopeResultReceived, $"Received scope result for {id} with status code {statusCode}.");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ public static Option<string> GetAuthTarget(Option<string> authChain)
return authChainIds.FirstOption(id => true);
}

public static Option<string> GetAuthParent(Option<string> authChain)
{
return authChain.Match(
chain =>
{
string[] authChainIds = GetAuthChainIds(chain);
// The auth target is second element after the target
return authChainIds.Skip(1).FirstOption(id => true);
},
() => Option.None<string>());
}

public static Option<string> GetActorDeviceId(Option<string> authChain)
{
if (!authChain.HasValue)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,21 @@ public Option<IDeviceProxy> GetDeviceConnection(string id)

public async Task<Option<ICloudProxy>> GetCloudConnection(string id)
{
Try<ICloudProxy> cloudProxyTry = await this.TryGetCloudConnection(id);
Try<ICloudProxy> cloudProxyTry = await this.TryGetCloudConnectionInternal(id);
return cloudProxyTry
.Ok()
.Map(c => (ICloudProxy)new RetryingCloudProxy(id, () => this.TryGetCloudConnection(id), c));
.Map(c => (ICloudProxy)new RetryingCloudProxy(id, () => this.TryGetCloudConnectionInternal(id), c));
}

async Task<Try<ICloudProxy>> TryGetCloudConnection(string id)
public async Task<Try<ICloudProxy>> TryGetCloudConnection(string id)
{
Try<ICloudProxy> cloudProxyTry = await this.TryGetCloudConnectionInternal(id);
return cloudProxyTry.Success
? Try.Success((ICloudProxy)new RetryingCloudProxy(id, () => this.TryGetCloudConnectionInternal(id), cloudProxyTry.Value))
: cloudProxyTry;
}

async Task<Try<ICloudProxy>> TryGetCloudConnectionInternal(string id)
{
IIdentity identity = this.identityProvider.Create(Preconditions.CheckNonWhiteSpace(id, nameof(id)));
ConnectedDevice device = this.GetOrCreateConnectedDevice(identity);
Expand Down Expand Up @@ -213,7 +221,7 @@ public async Task<Try<ICloudProxy>> CreateCloudConnectionAsync(IClientCredential
Events.NewCloudConnection(credentials.Identity, newCloudConnection);
Try<ICloudProxy> cloudProxyTry = GetCloudProxyFromCloudConnection(newCloudConnection, credentials.Identity);
return cloudProxyTry.Success
? Try.Success((ICloudProxy)new RetryingCloudProxy(credentials.Identity.Id, () => this.TryGetCloudConnection(credentials.Identity.Id), cloudProxyTry.Value))
? Try.Success((ICloudProxy)new RetryingCloudProxy(credentials.Identity.Id, () => this.TryGetCloudConnectionInternal(credentials.Identity.Id), cloudProxyTry.Value))
: cloudProxyTry;
}

Expand All @@ -231,7 +239,7 @@ public async Task<Try<ICloudProxy>> GetOrCreateCloudConnectionAsync(IClientCrede
Events.GetCloudConnection(credentials.Identity, cloudConnectionTry);
Try<ICloudProxy> cloudProxyTry = GetCloudProxyFromCloudConnection(cloudConnectionTry, credentials.Identity);
return cloudProxyTry.Success
? Try.Success((ICloudProxy)new RetryingCloudProxy(credentials.Identity.Id, () => this.TryGetCloudConnection(credentials.Identity.Id), cloudProxyTry.Value))
? Try.Success((ICloudProxy)new RetryingCloudProxy(credentials.Identity.Id, () => this.TryGetCloudConnectionInternal(credentials.Identity.Id), cloudProxyTry.Value))
: cloudProxyTry;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft. All rights reserved.
namespace Microsoft.Azure.Devices.Edge.Hub.Core
{
using System;

public class DeviceInvalidStateException : Exception
{
public DeviceInvalidStateException()
{
}

public DeviceInvalidStateException(string message)
: base(message)
{
}

public DeviceInvalidStateException(string message, Exception innerException)
: base(message, innerException)
{
}
}
}
Loading

0 comments on commit 51ad827

Please sign in to comment.