Skip to content
This repository was archived by the owner on Nov 1, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/Functions/AgentCommands.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
}
var nodeCommand = request.OkV;

var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId).FirstOrDefaultAsync();
var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId);
if (message != null) {
var command = message.Message;
var messageId = message.MessageId;
Expand Down
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/Functions/Node.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private async Async.Task<HttpResponseData> Get(HttpRequestData req) {

var (tasks, messages) = await (
_context.NodeTasksOperations.GetByMachineId(machineId).ToListAsync().AsTask(),
_context.NodeMessageOperations.GetMessage(machineId).ToListAsync().AsTask());
_context.NodeMessageOperations.GetMessages(machineId).ToListAsync().AsTask());

var commands = messages.Select(m => m.Message).ToList();
return await RequestHandling.Ok(req, NodeToNodeSearchResult(node with { Tasks = tasks, Messages = commands }));
Expand Down
12 changes: 9 additions & 3 deletions src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ApiService.OneFuzzLib.Orm;
using System.Threading.Tasks;
using ApiService.OneFuzzLib.Orm;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;

namespace Microsoft.OneFuzz.Service;
Expand All @@ -14,7 +15,9 @@ public NodeMessage(Guid machineId, NodeCommand message) : this(machineId, NewSor
};

public interface INodeMessageOperations : IOrm<NodeMessage> {
IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId);
IAsyncEnumerable<NodeMessage> GetMessages(Guid machineId);

Async.Task<NodeMessage?> GetMessage(Guid machineId);
Async.Task ClearMessages(Guid machineId);

Async.Task SendMessage(Guid machineId, NodeCommand message, string? messageId = null);
Expand All @@ -25,7 +28,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
public NodeMessageOperations(ILogTracer log, IOnefuzzContext context)
: base(log, context) { }

public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId)
public IAsyncEnumerable<NodeMessage> GetMessages(Guid machineId)
=> QueryAsync(Query.PartitionKey(machineId.ToString()));

public async Async.Task ClearMessages(Guid machineId) {
Expand All @@ -45,4 +48,7 @@ public async Async.Task SendMessage(Guid machineId, NodeCommand message, string?
_logTracer.WithHttpStatus(r.ErrorV).Error($"failed to insert message with id: {messageId:Tag:MessageId} for machine id: {machineId:Tag:MachineId} message: {message:Tag:Message}");
}
}

public async Task<NodeMessage?> GetMessage(Guid machineId)
=> await QueryAsync(Query.PartitionKey(machineId.ToString()), maxPerPage: 1).FirstOrDefaultAsync();
}
6 changes: 3 additions & 3 deletions src/ApiService/ApiService/onefuzzlib/orm/Orm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace ApiService.OneFuzzLib.Orm {
public interface IOrm<T> where T : EntityBase {
Task<TableClient> GetTableClient(string table, ResourceIdentifier? accountId = null);
IAsyncEnumerable<T> QueryAsync(string? filter = null);
IAsyncEnumerable<T> QueryAsync(string? filter = null, int? maxPerPage = null);

Task<T> GetEntityAsync(string partitionKey, string rowKey);
Task<ResultVoid<(HttpStatusCode Status, string Reason)>> Insert(T entity);
Expand Down Expand Up @@ -49,14 +49,14 @@ public Orm(ILogTracer logTracer, IOnefuzzContext context) {
_entityConverter = _context.EntityConverter;
}

public async IAsyncEnumerable<T> QueryAsync(string? filter = null) {
public async IAsyncEnumerable<T> QueryAsync(string? filter = null, int? maxPerPage = null) {
var tableClient = await GetTableClient(typeof(T).Name);

if (filter == "") {
filter = null;
}

await foreach (var x in tableClient.QueryAsync<TableEntity>(filter).Select(x => _entityConverter.ToRecord<T>(x))) {
await foreach (var x in tableClient.QueryAsync<TableEntity>(filter: filter, maxPerPage: maxPerPage).Select(x => _entityConverter.ToRecord<T>(x))) {
yield return x;
}
}
Expand Down
32 changes: 31 additions & 1 deletion src/ApiService/IntegrationTests/AgentCommandsTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Net;
using System;
using System.Net;
using FluentAssertions;
using IntegrationTests.Fakes;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.Functions;
Expand Down Expand Up @@ -50,4 +52,32 @@ public async Async.Task AgentAuthorization_IsAccepted() {
var result = await func.Run(TestHttpRequestData.Empty("GET"));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); // BadRequest due to no body, not Unauthorized
}

[Fact]
public async Async.Task AgentCommand_GetsCommand() {
var machineId = Guid.NewGuid();
var messageId = Guid.NewGuid().ToString();
var command = new NodeCommand {
Stop = new StopNodeCommand()
};
await Context.InsertAll(new[] {
new NodeMessage (
machineId,
messageId,
command
),
});

var commandRequest = new NodeCommandGet(machineId);
var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context);
var func = new AgentCommands(Logger, auth, Context);

var result = await func.Run(TestHttpRequestData.FromJson("GET", commandRequest));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);

var pendingNodeCommand = BodyAs<PendingNodeCommand>(result);
pendingNodeCommand.Envelope.Should().NotBeNull();
pendingNodeCommand.Envelope?.Command.Should().BeEquivalentTo(command);
pendingNodeCommand.Envelope?.MessageId.Should().Be(messageId);
}
}