Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] Streaming support in IAgent, support for streaming middleware plus pulling out pretty-print, user-input and function-call as middleware for easier testing #1656

Merged
merged 11 commits into from
Feb 28, 2024
Merged
Prev Previous commit
Next Next commit
pull out HumanInputMiddleware and FunctionCallMiddleware
  • Loading branch information
LittleLittleCloud committed Feb 13, 2024
commit b95652fac498c169f0ba8a6f8ad97ac15455255a
36 changes: 18 additions & 18 deletions dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,27 +66,27 @@ public void Use(IMiddleware middleware)
{
this.middlewares.Add(middleware);
}
}

internal class DelegateAgent : IAgent
{
private readonly IAgent innerAgent;
private readonly IMiddleware middleware;

public DelegateAgent(IMiddleware middleware, IAgent innerAgent)
private class DelegateAgent : IAgent
{
this.middleware = middleware;
this.innerAgent = innerAgent;
}
private readonly IAgent innerAgent;
private readonly IMiddleware middleware;

public DelegateAgent(IMiddleware middleware, IAgent innerAgent)
{
this.middleware = middleware;
this.innerAgent = innerAgent;
}

public string? Name { get => this.innerAgent.Name; }
public string? Name { get => this.innerAgent.Name; }

public Task<Message> GenerateReplyAsync(
IEnumerable<Message> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var context = new MiddlewareContext(messages, options);
return this.middleware.InvokeAsync(context, this.innerAgent, cancellationToken);
public Task<Message> GenerateReplyAsync(
IEnumerable<Message> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var context = new MiddlewareContext(messages, options);
return this.middleware.InvokeAsync(context, this.innerAgent, cancellationToken);
}
}
}
48 changes: 24 additions & 24 deletions dotnet/src/AutoGen/Core/Agent/MiddlewareStreamingAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,37 +59,37 @@ public void Use(Func<MiddlewareContext, IStreamingAgent, CancellationToken, Task
{
_middlewares.Add(new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func)));
}
}

internal class DelegateStreamingAgent : IStreamingAgent
{
private IStreamingMiddleware middleware;
private IStreamingAgent innerAgent;

public string? Name => innerAgent.Name;

public DelegateStreamingAgent(IStreamingMiddleware middleware, IStreamingAgent next)
private class DelegateStreamingAgent : IStreamingAgent
{
this.middleware = middleware;
this.innerAgent = next;
}
private IStreamingMiddleware middleware;
private IStreamingAgent innerAgent;

public async Task<Message> GenerateReplyAsync(IEnumerable<Message> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
var stream = await GenerateStreamingReplyAsync(messages, options, cancellationToken);
var result = default(Message);
public string? Name => innerAgent.Name;

await foreach (var message in stream.WithCancellation(cancellationToken))
public DelegateStreamingAgent(IStreamingMiddleware middleware, IStreamingAgent next)
{
result = message;
this.middleware = middleware;
this.innerAgent = next;
}

return result ?? throw new InvalidOperationException("No message returned from the streaming agent.");
}
public async Task<Message> GenerateReplyAsync(IEnumerable<Message> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
var stream = await GenerateStreamingReplyAsync(messages, options, cancellationToken);
var result = default(Message);

public Task<IAsyncEnumerable<Message>> GenerateStreamingReplyAsync(IEnumerable<Message> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
var context = new MiddlewareContext(messages, options);
return middleware.InvokeAsync(context, innerAgent, cancellationToken);
await foreach (var message in stream.WithCancellation(cancellationToken))
{
result = message;
}

return result ?? throw new InvalidOperationException("No message returned from the streaming agent.");
}

public Task<IAsyncEnumerable<Message>> GenerateStreamingReplyAsync(IEnumerable<Message> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
var context = new MiddlewareContext(messages, options);
return middleware.InvokeAsync(context, innerAgent, cancellationToken);
}
}
}
21 changes: 18 additions & 3 deletions dotnet/src/AutoGen/Core/Extension/MiddlewareExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,6 @@ public static IAgent RegisterPreProcess(
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
/// <param name="agent"></param>
/// <param name="func"></param>
/// <returns></returns>
public static MiddlewareAgent RegisterMiddleware(
this IAgent agent,
Func<IEnumerable<Message>, GenerateReplyOptions?, IAgent, CancellationToken, Task<Message>> func,
Expand All @@ -106,4 +103,22 @@ public static MiddlewareAgent RegisterMiddleware(

return middlewareAgent;
}

/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareAgent RegisterMiddleware(
this IAgent agent,
IMiddleware middleware)
{
if (agent.Name == null)
{
throw new Exception("Agent name is null.");
}

var middlewareAgent = new MiddlewareAgent(agent);
middlewareAgent.Use(middleware);

return middlewareAgent;
}
}
17 changes: 17 additions & 0 deletions dotnet/src/AutoGen/Core/IAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// IAgent.cs

using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
Expand All @@ -25,6 +26,22 @@ public Task<Message> GenerateReplyAsync(

public class GenerateReplyOptions
{
public GenerateReplyOptions()
{
}

/// <summary>
/// Copy constructor
/// </summary>
/// <param name="other">other option to copy from</param>
public GenerateReplyOptions(GenerateReplyOptions other)
{
this.Temperature = other.Temperature;
this.MaxToken = other.MaxToken;
this.StopSequence = other.StopSequence?.Select(s => s)?.ToArray();
this.Functions = other.Functions?.Select(f => f)?.ToArray();
}

public float? Temperature { get; set; }

public int? MaxToken { get; set; }
Expand Down
45 changes: 45 additions & 0 deletions dotnet/src/AutoGen/Core/Middleware/DelegateMiddleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DelegateMiddleware.cs

using System;
using System.Threading;
using System.Threading.Tasks;

namespace AutoGen.Core.Middleware;

internal class DelegateMiddleware : IMiddleware
{
/// <summary>
/// middleware delegate. Call into the next function to continue the execution of the next middleware. Otherwise, short cut the middleware execution.
/// </summary>
/// <param name="cancellationToken">cancellation token</param>
public delegate Task<Message> MiddlewareDelegate(
MiddlewareContext context,
IAgent agent,
CancellationToken cancellationToken);

private readonly MiddlewareDelegate middlewareDelegate;

public DelegateMiddleware(string? name, Func<MiddlewareContext, IAgent, CancellationToken, Task<Message>> middlewareDelegate)
{
this.Name = name;
this.middlewareDelegate = async (context, agent, cancellationToken) =>
{
return await middlewareDelegate(context, agent, cancellationToken);
};
}

public string? Name { get; }

public Task<Message> InvokeAsync(
MiddlewareContext context,
IAgent agent,
CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var options = context.Options;

return this.middlewareDelegate(context, agent, cancellationToken);
}
}

38 changes: 38 additions & 0 deletions dotnet/src/AutoGen/Core/Middleware/DelegateStreamingMiddleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DelegateStreamingMiddleware.cs

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace AutoGen.Core.Middleware;

internal class DelegateStreamingMiddleware : IStreamingMiddleware
{
public delegate Task<IAsyncEnumerable<Message>> MiddlewareDelegate(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken);

private readonly MiddlewareDelegate middlewareDelegate;

public DelegateStreamingMiddleware(string? name, MiddlewareDelegate middlewareDelegate)
{
this.Name = name;
this.middlewareDelegate = middlewareDelegate;
}

public string? Name { get; }

public Task<IAsyncEnumerable<Message>> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var options = context.Options;

return this.middlewareDelegate(context, agent, cancellationToken);
}
}

83 changes: 83 additions & 0 deletions dotnet/src/AutoGen/Core/Middleware/FunctionCallMiddleware.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionCallMiddleware.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;

namespace AutoGen.Core.Middleware;

/// <summary>
/// The middleware that process function call message that both send to an agent or reply from an agent.
/// </summary>
public class FunctionCallMiddleware : IMiddleware
{
private readonly IEnumerable<FunctionDefinition>? functions;
private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;
public FunctionCallMiddleware(
IEnumerable<FunctionDefinition>? functions = null,
IDictionary<string, Func<string, Task<string>>>? functionMap = null,
string? name = null)
{
this.Name = name ?? nameof(FunctionCallMiddleware);
this.functions = functions;
this.functionMap = functionMap;
}

public string? Name { get; }

public async Task<Message> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
// if the last message is a function call message, invoke the function and return the result instead of sending to the agent.
var lastMessage = context.Messages.Last();
if (lastMessage is not null && lastMessage is { FunctionName: string functionName, FunctionArguments: string functionArguments })
{
if (this.functionMap?.TryGetValue(functionName, out var func) is true)
{
var result = await func(functionArguments);
return new Message(role: Role.Function, content: result, from: lastMessage.From)
{
FunctionName = functionName,
FunctionArguments = functionArguments,
};
}
else
{
var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";

return new Message(role: Role.Function, content: errorMessage, from: lastMessage.From)
{
FunctionName = functionName,
FunctionArguments = functionArguments,
};
}
}

// combine functions
var options = new GenerateReplyOptions(context.Options ?? new GenerateReplyOptions());
var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
options.Functions = combinedFunctions?.ToArray();

var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken);

// if the reply is a function call message plus the function's name is available in function map, invoke the function and return the result instead of sending to the agent.
if (reply is { FunctionName: string fName, FunctionArguments: string fArgs })
{
if (this.functionMap?.TryGetValue(fName, out var func) is true)
{
var result = await func(fArgs);
return new Message(role: Role.Assistant, content: result, from: reply.From)
{
FunctionName = fName,
FunctionArguments = fArgs,
};
}
}

// for all other messages, just return the reply from the agent.
return reply;
}
}
Loading