Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] Streaming support in IAgent, support for streaming middleware plus pulling out pretty-print, user-input and function-call as middleware for easier testing #1656

Merged
merged 11 commits into from
Feb 28, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public async Task CodeSnippet1()
// set human input mode to ALWAYS so that user always provide input
var userProxyAgent = new UserProxyAgent(
name: "user",
humanInputMode: ConversableAgent.HumanInputMode.ALWAYS)
humanInputMode: HumanInputMode.ALWAYS)
.RegisterPrintFormatMessageHook();

// start the conversation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ public async Task CodeSnippet1()
#endregion code_snippet_1

#region code_snippet_2
middlewareAgent.Use(async (messages, options, next, ct) =>
middlewareAgent.Use(async (messages, options, agent, ct) =>
{
var lastMessage = messages.Last();
lastMessage.Content = $"[middleware 0] {lastMessage.Content}";
return await next(messages, options, ct);
return await agent.GenerateReplyAsync(messages, options, ct);
});

reply = await middlewareAgent.SendAsync("Hello World");
Expand All @@ -46,11 +46,11 @@ public async Task CodeSnippet1()
reply.Content.Should().Be("[middleware 0] Hello World");
#endregion code_snippet_2_1
#region code_snippet_3
middlewareAgent.Use(async (messages, options, next, ct) =>
middlewareAgent.Use(async (messages, options, agent, ct) =>
{
var lastMessage = messages.Last();
lastMessage.Content = $"[middleware 1] {lastMessage.Content}";
return await next(messages, options, ct);
return await agent.GenerateReplyAsync(messages, options, ct);
});

reply = await middlewareAgent.SendAsync("Hello World");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public async Task CodeSnippet1()
// create a user proxy agent which always ask user for input
var agent = new UserProxyAgent(
name: "user",
humanInputMode: ConversableAgent.HumanInputMode.ALWAYS);
humanInputMode: HumanInputMode.ALWAYS);

await agent.SendAsync("hello");
#endregion code_snippet_1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public static async Task RunAsync()
// set human input mode to ALWAYS so that user always provide input
var userProxyAgent = new UserProxyAgent(
name: "user",
humanInputMode: ConversableAgent.HumanInputMode.ALWAYS)
humanInputMode: HumanInputMode.ALWAYS)
.RegisterPrintFormatMessageHook();

// start the conversation
Expand Down
169 changes: 34 additions & 135 deletions dotnet/src/AutoGen/Core/Agent/ConversableAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,31 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core.Middleware;
using AutoGen.OpenAI;

namespace AutoGen;

public class ConversableAgent : IAgent
public enum HumanInputMode
{
public enum HumanInputMode
{
NEVER = 0,
ALWAYS = 1,
AUTO = 2,
}
/// <summary>
/// NEVER prompt the user for input
/// </summary>
NEVER = 0,

/// <summary>
/// ALWAYS prompt the user for input
/// </summary>
ALWAYS = 1,

/// <summary>
/// prompt the user for input if the message is not a termination message
/// </summary>
AUTO = 2,
}

public class ConversableAgent : IAgent
{
private readonly IAgent? innerAgent;
private readonly string? defaultReply;
private readonly HumanInputMode humanInputMode;
Expand All @@ -40,7 +52,6 @@ public ConversableAgent(
this.humanInputMode = humanInputMode;
this.innerAgent = innerAgent;
this.IsTermination = isTermination;
this.defaultReply = defaultAutoReply;
this.systemMessage = systemMessage;
}

Expand All @@ -62,27 +73,6 @@ public ConversableAgent(
this.innerAgent = llmConfig?.ConfigList != null ? this.CreateInnerAgentFromConfigList(llmConfig) : null;
}

/// <summary>
/// Override this method to change the behavior of getting human input
/// </summary>
public virtual Task<string?> GetHumanInputAsync()
{
// first, write prompt, then read from stdin
var prompt = "Please give feedback: Press enter or type 'exit' to stop the conversation.";
Console.WriteLine(prompt);
var userInput = Console.ReadLine();
if (!string.IsNullOrEmpty(userInput) && userInput.ToLower() != "exit")
{
return Task.FromResult<string?>(userInput);
}
else
{
Console.WriteLine("Terminating the conversation");
userInput = GroupChatExtension.TERMINATE;
return Task.FromResult<string?>(userInput);
}
}

private IAgent? CreateInnerAgentFromConfigList(ConversableAgentConfig config)
{
IAgent? agent = null;
Expand Down Expand Up @@ -126,15 +116,12 @@ public async Task<Message> GenerateReplyAsync(
// first in, last out

// process default reply
IAgent agent = new DefaultReplyAgent(this.Name!, this.defaultReply ?? "Default reply is not set. Please pass a default reply to assistant agent");

// process inner agent
agent = agent.RegisterReply(async (messages, cancellationToken) =>
MiddlewareAgent agent;
if (this.innerAgent != null)
{
if (this.innerAgent != null)
agent = innerAgent.RegisterMiddleware(async (msgs, option, agent, ct) =>
{
// for every message, update message.From to inner agent's name if it is the name of this assistant agent
var updatedMessages = messages.Select(m =>
var updatedMessages = msgs.Select(m =>
{
if (m.From == this.Name)
{
Expand All @@ -148,110 +135,22 @@ public async Task<Message> GenerateReplyAsync(
}
});

var msg = await this.innerAgent.GenerateReplyAsync(updatedMessages, overrideOptions, cancellationToken);
msg.From = this.Name;

return msg;
}
else
{
return null;
}
});

// process human input
agent = agent.RegisterReply(async (messages, cancellationToken) =>
return await agent.GenerateReplyAsync(updatedMessages, option, ct);
});
}
else
{
async Task<Message> TakeUserInputAsync()
{
var input = await this.GetHumanInputAsync();
if (input != null)
{
return new Message(Role.Assistant, input, from: this.Name);
}
else
{
return new Message(Role.Assistant, string.Empty, from: this.Name);
}
}
agent = new MiddlewareAgent(new DefaultReplyAgent(this.Name!, this.defaultReply ?? "Default reply is not set. Please pass a default reply to assistant agent"));
}

if (this.humanInputMode == HumanInputMode.ALWAYS)
{
return await TakeUserInputAsync();
}
else if (this.humanInputMode == HumanInputMode.AUTO)
{
if (this.IsTermination != null && await this.IsTermination(messages, cancellationToken))
{
return await TakeUserInputAsync();
}
else
{
return null;
}
}
else
{
return null;
}
});
// process human input
var humanInputMiddleware = new HumanInputMiddleware(mode: this.humanInputMode, isTermination: this.IsTermination);
agent.Use(humanInputMiddleware);

// process function call
agent = agent.RegisterMiddleware(async (messages, option, innerAgent, cancellationToken) =>
{
if (this.functionMap != null &&
messages.Last()?.FunctionName is string functionName &&
messages.Last()?.FunctionArguments is string functionArguments &&
messages.Last()?.Content is null &&
this.functionMap.ContainsKey(functionName) &&
messages.Last().From != this.Name)
{
var reply = await this.ExecuteFunctionCallAsync(messages.Last(), cancellationToken);

// perform as a proxy to run function call from external agent, therefore, the reply should be from the last message's sender
reply.Role = Role.Function;

return reply;
}

var agentReply = await innerAgent.GenerateReplyAsync(messages, option, cancellationToken);
if (this.functionMap != null && agentReply.FunctionName is string && agentReply.FunctionArguments is string)
{
return await this.ExecuteFunctionCallAsync(agentReply, cancellationToken);
}
else
{
return agentReply;
}
});
var functionCallMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
agent.Use(functionCallMiddleware);

return await agent.GenerateReplyAsync(messages, overrideOptions, cancellationToken);
}

private async Task<Message> ExecuteFunctionCallAsync(Message message, CancellationToken _)
{
if (message.FunctionName is string functionName && message.FunctionArguments is string functionArguments && this.functionMap != null)
{
if (this.functionMap.TryGetValue(functionName, out var func))
{
var result = await func(functionArguments);
return new Message(Role.Assistant, result, from: this.Name)
{
FunctionName = functionName,
FunctionArguments = functionArguments,
};
}
else
{
var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";
return new Message(Role.Assistant, errorMessage, from: this.Name)
{
FunctionName = functionName,
FunctionArguments = functionArguments,
};
}
}

throw new Exception("Function call is not available. Please pass a function map to assistant agent");
}
}
73 changes: 47 additions & 26 deletions dotnet/src/AutoGen/Core/Agent/MiddlewareAgent.cs
Original file line number Diff line number Diff line change
@@ -1,38 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MiddlewareAgent.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core.Middleware;

namespace AutoGen;

public delegate Task<Message> GenerateReplyDelegate(
IEnumerable<Message> messages,
GenerateReplyOptions? options,
CancellationToken cancellationToken);

/// <summary>
/// middleware delegate. Call into the next function to continue the execution of the next middleware. Otherwise, short cut the middleware execution.
/// </summary>
/// <param name="messages">messages to process</param>
/// <param name="options">options</param>
/// <param name="cancellationToken">cancellation token</param>
/// <param name="next">next middleware</param>
public delegate Task<Message> MiddlewareDelegate(
IEnumerable<Message> messages,
GenerateReplyOptions? options,
GenerateReplyDelegate next,
CancellationToken cancellationToken);

/// <summary>
/// An agent that allows you to add middleware and modify the behavior of an existing agent.
/// </summary>
public class MiddlewareAgent : IAgent
{
private readonly IAgent innerAgent;
private readonly List<MiddlewareDelegate> middlewares = new();
private readonly List<IMiddleware> middlewares = new();

/// <summary>
/// Create a new instance of <see cref="MiddlewareAgent"/>
Expand All @@ -52,20 +35,58 @@ public Task<Message> GenerateReplyAsync(
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var middleware = this.middlewares.Aggregate(
(GenerateReplyDelegate)this.innerAgent.GenerateReplyAsync,
(next, current) => (messages, options, cancellationToken) => current(messages, options, next, cancellationToken));
var agent = this.innerAgent;
foreach (var middleware in this.middlewares)
{
agent = new DelegateAgent(middleware, agent);
}

return middleware(messages, options, cancellationToken);
return agent.GenerateReplyAsync(messages, options, cancellationToken);
}

/// <summary>
/// Add a middleware to the agent. If multiple middlewares are added, they will be executed in the LIFO order.
/// Call into the next function to continue the execution of the next middleware.
/// Short cut middleware execution by not calling into the next function.
/// </summary>
public void Use(MiddlewareDelegate func)
public void Use(Func<IEnumerable<Message>, GenerateReplyOptions?, IAgent, CancellationToken, Task<Message>> func, string? middlewareName = null)
{
this.middlewares.Add(new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
{
return await func(context.Messages, context.Options, agent, cancellationToken);
}));
}

public void Use(Func<MiddlewareContext, IAgent, CancellationToken, Task<Message>> func, string? middlewareName = null)
{
this.middlewares.Add(new DelegateMiddleware(middlewareName, func));
}

public void Use(IMiddleware middleware)
{
this.middlewares.Add(middleware);
}

private class DelegateAgent : IAgent
{
this.middlewares.Add(func);
private readonly IAgent innerAgent;
private readonly IMiddleware middleware;

public DelegateAgent(IMiddleware middleware, IAgent innerAgent)
{
this.middleware = middleware;
this.innerAgent = innerAgent;
}

public string? Name { get => this.innerAgent.Name; }

public Task<Message> GenerateReplyAsync(
IEnumerable<Message> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var context = new MiddlewareContext(messages, options);
return this.middleware.InvokeAsync(context, this.innerAgent, cancellationToken);
}
}
}
Loading
Loading