Skip to content

Commit

Permalink
Added a separate message pump for messages from upstream (Azure#4481)
Browse files Browse the repository at this point in the history
The reason of the change is that when there is a big load of messages (from devices/modules), the processing queue can grow long. In this case, if there are messages from upstream (e.g twin result, direct method call), those messages get at the end of the queue and takes some time to process them. This change creates a separate queue for messages from upstream, increasing the reaction time of the system in certain scenarios.

The root of the change is in the method ForwardPublish(). This method gets calld by the mqtt library when a message arrives, and this puts the messages into a processing queue and returns immediately (giving back the control to the mqtt library as soon as possible). Now the change checks for messages from upstream (those all start with string "$downstream", and puts them into a separate queue.

This separate queue gets processed by a specific processing loop, implemented by UpstreamLoop(). Before this change, there was only a single loop, now there is an UpstreamLoop() and DownstreamLoop(), where UpstreamLoop() processes message coming from upstream. This loop uses a simple message handler class called "BrokeredCloudProxyDispatcher" as all messages from upstream goes through that class. DownstreamLoop() check all available message handlers (including BrokeredCloudProxyDispatcher), because it is not known which handler would process the messages.

Note, that it is intentional leaving the BrokeredCloudProxyDispatcher class for DownstreamLoop(), so if some message breaks the convention (not starting with $downstream), that still will be handled.
  • Loading branch information
vipeller authored Feb 25, 2021
1 parent 66655b2 commit 0e69854
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ public class MqttBrokerConnector : IMqttBrokerConnector

readonly TaskCompletionSource<bool> onConnectedTcs = new TaskCompletionSource<bool>();

Option<Channel<MqttPublishInfo>> publications;
Option<Task> forwardingLoop;
Option<Channel<MqttPublishInfo>> upstreamPublications;
Option<Channel<MqttPublishInfo>> downstreamPublications;
Option<Task> forwardingLoops;
Option<MqttClient> mqttClient;

AtomicBoolean isRetrying = new AtomicBoolean(false);
Expand Down Expand Up @@ -72,14 +73,7 @@ public async Task ConnectAsync(string serverAddress, int port)
client.MqttMsgPublished += this.ConfirmPublished;
client.MqttMsgPublishReceived += this.ForwardPublish;

this.publications = Option.Some(Channel.CreateUnbounded<MqttPublishInfo>(
new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = true
}));

this.forwardingLoop = Option.Some(this.StartForwardingLoop());
this.forwardingLoops = Option.Some(this.StartForwardingLoops());

// if ConnectAsync is supposed to manage starting it with broker down,
// put a loop here to keep trying - see 'TriggerReconnect' below
Expand All @@ -90,7 +84,7 @@ public async Task ConnectAsync(string serverAddress, int port)
client.MqttMsgPublished -= this.ConfirmPublished;
client.MqttMsgPublishReceived -= this.ForwardPublish;

await this.StopForwardingLoopAsync();
await this.StopForwardingLoopsAsync();

lock (this.guard)
{
Expand Down Expand Up @@ -149,7 +143,7 @@ await clientToStop.ForEachAsync(
}
}
await this.StopForwardingLoopAsync();
await this.StopForwardingLoopsAsync();
Events.Closed();
},
Expand Down Expand Up @@ -215,9 +209,20 @@ public async Task<bool> SendAsync(string topic, byte[] payload, bool retain = fa

void ForwardPublish(object sender, MqttMsgPublishEventArgs e)
{
var isWritten = this.publications.Match(
bool isWritten;
if (!string.IsNullOrEmpty(e.Topic) && e.Topic.StartsWith("$downstream/"))
{
// messages from upstream come with prefix downstream - because for the parent we are downstream
isWritten = this.upstreamPublications.Match(
channel => channel.Writer.TryWrite(new MqttPublishInfo(e.Topic, e.Message)),
() => false);
}
else
{
isWritten = this.downstreamPublications.Match(
channel => channel.Writer.TryWrite(new MqttPublishInfo(e.Topic, e.Message)),
() => false);
}

if (!isWritten)
{
Expand Down Expand Up @@ -305,71 +310,122 @@ void TriggerReconnect(object sender, EventArgs e)
TaskCreationOptions.LongRunning);
}

Task StartForwardingLoop()
Task StartForwardingLoops()
{
var loopTask = Task.Factory.StartNew(
async () =>
{
Events.ForwardingLoopStarted();
while (await this.publications.Expect(ChannelIsBroken).Reader.WaitToReadAsync())
{
var publishInfo = default(MqttPublishInfo);
try
{
publishInfo = await this.publications.Expect(ChannelIsBroken).Reader.ReadAsync();
}
catch (Exception e)
{
Events.FailedToForward(e);
continue;
}
var accepted = false;
foreach (var consumer in this.components.Consumers)
{
try
this.CreateMessageChannels();

var downstreamTask = Task.Factory.StartNew(this.DownstreamLoop, TaskCreationOptions.LongRunning);
var upstreamTask = Task.Factory.StartNew(this.UpstreamLoop, TaskCreationOptions.LongRunning);

return Task.WhenAll(downstreamTask, upstreamTask);
}

async Task StopForwardingLoopsAsync()
{
this.downstreamPublications.ForEach(channel => channel.Writer.Complete());
this.upstreamPublications.ForEach(channel => channel.Writer.Complete());

await this.forwardingLoops.ForEachAsync(loop => loop);

this.forwardingLoops = Option.None<Task>();
this.downstreamPublications = Option.None<Channel<MqttPublishInfo>>();
this.upstreamPublications = Option.None<Channel<MqttPublishInfo>>();
}

void CreateMessageChannels()
{
this.downstreamPublications = Option.Some(Channel.CreateUnbounded<MqttPublishInfo>(
new UnboundedChannelOptions
{
accepted = await consumer.HandleAsync(publishInfo);
if (accepted)
{
Events.MessageForwarded(consumer.GetType().Name, accepted, publishInfo.Topic, publishInfo.Payload.Length);
break;
}
}
catch (Exception e)
SingleReader = true,
SingleWriter = true
}));

this.upstreamPublications = Option.Some(Channel.CreateUnbounded<MqttPublishInfo>(
new UnboundedChannelOptions
{
Events.FailedToForward(e);
// Keep going with other consumers...
}
}
SingleReader = true,
SingleWriter = true
}));
}

if (!accepted)
{
Events.MessageNotForwarded(publishInfo.Topic, publishInfo.Payload.Length);
}
}
async Task DownstreamLoop()
{
Events.DownstreamForwardingLoopStarted();

Events.ForwardingLoopStopped();
},
TaskCreationOptions.LongRunning);
var channel = this.downstreamPublications.Expect(() => new Exception("No downstream channel is prepared to read"));
while (await channel.Reader.WaitToReadAsync())
{
var publishInfo = default(MqttPublishInfo);

return loopTask;
try
{
publishInfo = await channel.Reader.ReadAsync();
}
catch (Exception e)
{
Events.FailedToForwardDownstream(e);
continue;
}

Exception ChannelIsBroken()
{
return new Exception("Channel is broken, exiting forwarding loop by error");
var accepted = false;
foreach (var consumer in this.components.Consumers)
{
try
{
accepted = await consumer.HandleAsync(publishInfo);
if (accepted)
{
Events.MessageForwarded(consumer.GetType().Name, accepted, publishInfo.Topic, publishInfo.Payload.Length);
break;
}
}
catch (Exception e)
{
Events.FailedToForwardDownstream(e);
// Keep going with other consumers...
}
}

if (!accepted)
{
Events.MessageNotForwarded(publishInfo.Topic, publishInfo.Payload.Length);
}
}

Events.DownstreamForwardingLoopStopped();
}

async Task StopForwardingLoopAsync()
async Task UpstreamLoop()
{
this.publications.ForEach(channel => channel.Writer.Complete());
var upstreamDispatcher = this.components.Consumers.Where(c => c is BrokeredCloudProxyDispatcher).FirstOrDefault();
if (upstreamDispatcher == null)
{
throw new InvalidOperationException("There is no BrokeredCloudProxyDispatcher found in message consumer list");
}

Events.UpstreamForwardingLoopStarted();

var channel = this.upstreamPublications.Expect(() => new Exception("No upstream channel is prepared to read"));
while (await channel.Reader.WaitToReadAsync())
{
var publishInfo = default(MqttPublishInfo);

await this.forwardingLoop.ForEachAsync(loop => loop);
try
{
publishInfo = await channel.Reader.ReadAsync();

var accepted = await upstreamDispatcher.HandleAsync(publishInfo);
Events.MessageForwarded(upstreamDispatcher.GetType().Name, accepted, publishInfo.Topic, publishInfo.Payload.Length);
}
catch (Exception e)
{
Events.FailedToForwardUpstream(e);
// keep going
}
}

this.forwardingLoop = Option.None<Task>();
this.publications = Option.None<Channel<MqttPublishInfo>>();
Events.UpstreamForwardingLoopStopped();
}

// these are statics, so they don't use the state to acquire 'client' - making easier to handle parallel
Expand Down Expand Up @@ -483,8 +539,12 @@ enum EventIds
QosMismatch,
UnknownMessageId,
CouldNotForwardMessage,
ForwardingLoopStarted,
ForwardingLoopStopped,
DownstreamForwardingLoopStarted,
DownstreamForwardingLoopStopped,
UpstreamForwardingLoopStarted,
UpstreamForwardingLoopStopped,
FailedToForwardUpstream,
FailedToForwardDownstream,
MessageForwarded,
MessageNotForwarded,
FailedToForward,
Expand All @@ -503,11 +563,14 @@ enum EventIds
public static void QosMismatch() => Log.LogError((int)EventIds.QosMismatch, "MQTT server did not grant QoS for every requested subscription");
public static void UnknownMessageId(ushort id) => Log.LogError((int)EventIds.UnknownMessageId, "Unknown message id received : {0}", id);
public static void CouldNotForwardMessage(string topic, int len) => Log.LogWarning((int)EventIds.CouldNotForwardMessage, "Could not forward MQTT message from connector. Topic {0}, Msg. len {1} bytes", topic, len);
public static void ForwardingLoopStarted() => Log.LogInformation((int)EventIds.ForwardingLoopStarted, "Forwarding loop started");
public static void ForwardingLoopStopped() => Log.LogInformation((int)EventIds.ForwardingLoopStopped, "Forwarding loop stopped");
public static void DownstreamForwardingLoopStarted() => Log.LogInformation((int)EventIds.DownstreamForwardingLoopStarted, "Downstream forwarding loop started");
public static void DownstreamForwardingLoopStopped() => Log.LogInformation((int)EventIds.DownstreamForwardingLoopStopped, "Downstream forwarding loop stopped");
public static void UpstreamForwardingLoopStarted() => Log.LogInformation((int)EventIds.UpstreamForwardingLoopStarted, "Upstream forwarding loop started");
public static void UpstreamForwardingLoopStopped() => Log.LogInformation((int)EventIds.UpstreamForwardingLoopStopped, "Upstream forwarding loop stopped");
public static void MessageForwarded(string consumer, bool accepted, string topic, int len) => Log.LogDebug((int)EventIds.MessageForwarded, "Message forwarded to {0} and it {1}. Topic {2}, Msg. len {3} bytes", consumer, accepted ? "accepted" : "ignored", topic, len);
public static void MessageNotForwarded(string topic, int len) => Log.LogDebug((int)EventIds.MessageForwarded, "Message has not been forwarded to any consumers. Topic {0}, Msg. len {1} bytes", topic, len);
public static void FailedToForward(Exception e) => Log.LogError((int)EventIds.FailedToForward, e, "Failed to forward message.");
public static void FailedToForwardUpstream(Exception e) => Log.LogError((int)EventIds.FailedToForwardUpstream, e, "Failed to forward message from upstream.");
public static void FailedToForwardDownstream(Exception e) => Log.LogError((int)EventIds.FailedToForwardDownstream, e, "Failed to forward message from downstream.");
public static void CouldNotConnect() => Log.LogInformation((int)EventIds.CouldNotConnect, "Could not connect to MQTT Broker, possibly it is not running. To disable MQTT Broker Connector, please set 'mqttBrokerSettings__enabled' environment variable to 'false'");
public static void TimeoutReceivingSubAcks(Exception e) => Log.LogError((int)EventIds.TimeoutReceivingSubAcks, e, "MQTT Broker has not acknowledged subscriptions in time");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ namespace Microsoft.Azure.Devices.Edge.Hub.MqttBrokerAdapter.Test
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Devices.Edge.Hub.Core;
using Microsoft.Azure.Devices.Edge.Hub.Core.Device;
using Microsoft.Azure.Devices.Edge.Hub.Core.Identity;
using Microsoft.Azure.Devices.Edge.Util;
using Microsoft.Azure.Devices.Edge.Util.Test.Common;
using Moq;
Expand Down Expand Up @@ -374,6 +377,50 @@ public async Task OfflineSendGetSentAfterReconnect()
Assert.Equal("hoo", Encoding.ASCII.GetString(broker.Publications.First().Item2));
}
}

[Fact]
public async Task MessagesFromUpstreamHandledOnSeparatePath()
{
using var broker = new MiniMqttServer();

var upstreamMilestone = new SemaphoreSlim(0, 1);
var downstreamMilestone = new SemaphoreSlim(0, 1);

var edgeHubStub = new EdgeHubStub(() => upstreamMilestone.Release());

var upstreamDispatcher = new BrokeredCloudProxyDispatcher();
upstreamDispatcher.BindEdgeHub(edgeHubStub);

var consumers = new IMessageConsumer[]
{
new ConsumerStub
{
ShouldHandle = true,
Handler = _ =>
{
upstreamMilestone.Wait(); // this blocks the pump for downstream messages
downstreamMilestone.Release();
}
},

upstreamDispatcher
};

using var sut = new ConnectorBuilder()
.WithConsumers(consumers)
.Build();

await sut.ConnectAsync(HOST, broker.Port);

// handled by downstream pump and the pump gets blocked
await broker.PublishAsync("boo", Encoding.ASCII.GetBytes("hoo"));

// handled by upstream pump that let's downstream pump getting unblocked
await broker.PublishAsync("$downstream/device_1/methods/post/foo/?$rid=123", new byte[0]);

// check if downsteam pump got unblocked
Assert.True(await downstreamMilestone.WaitAsync(TimeSpan.FromSeconds(5)));
}
}

class ProducerStub : IMessageProducer
Expand Down Expand Up @@ -457,6 +504,51 @@ public DisposableMqttBrokerConnector Build()
}
}

class EdgeHubStub : IEdgeHub
{
Action whenCalled;

public EdgeHubStub(Action whenCalled)
{
this.whenCalled = whenCalled;
}

public void Dispose()
{
}

public IDeviceScopeIdentitiesCache GetDeviceScopeIdentitiesCache() => throw new NotImplementedException();
public string GetEdgeDeviceId() => "x";

public Task<IMessage> GetTwinAsync(string id)
{
this.whenCalled();
return Task.FromResult(new EdgeMessage(new byte[0], new Dictionary<string, string>(), new Dictionary<string, string>()) as IMessage);
}

public Task<DirectMethodResponse> InvokeMethodAsync(string id, DirectMethodRequest methodRequest)
{
this.whenCalled();
return Task.FromResult(new DirectMethodResponse("boo", new byte[0], 200));
}

public Task ProcessDeviceMessage(IIdentity identity, IMessage message) => WhenCalled();
public Task ProcessDeviceMessageBatch(IIdentity identity, IEnumerable<IMessage> message) => WhenCalled();
public Task AddSubscription(string id, DeviceSubscription deviceSubscription) => WhenCalled();
public Task ProcessSubscriptions(string id, IEnumerable<(DeviceSubscription, bool)> subscriptions) => WhenCalled();
public Task RemoveSubscription(string id, DeviceSubscription deviceSubscription) => WhenCalled();
public Task RemoveSubscriptions(string id) => WhenCalled();
public Task SendC2DMessageAsync(string id, IMessage message) => WhenCalled();
public Task UpdateDesiredPropertiesAsync(string id, IMessage twinCollection) => WhenCalled();
public Task UpdateReportedPropertiesAsync(IIdentity identity, IMessage reportedPropertiesMessage) => WhenCalled();

Task WhenCalled()
{
this.whenCalled();
return Task.CompletedTask;
}
}

class MiniMqttServer : IDisposable
{
CancellationTokenSource cts;
Expand Down

0 comments on commit 0e69854

Please sign in to comment.