Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[.Net] Mark Message as obsolete and add ToolCallAggregateMessage type #2716

Merged
merged 11 commits into from
May 21, 2024
Prev Previous commit
Next Next commit
fix tests
  • Loading branch information
LittleLittleCloud committed May 21, 2024
commit babba87de332ca004a52e1c17a6b9f6e6bbf0267
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/GroupChat/GroupChat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public async Task<IAgent> SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumera
{string.Join(",", agentNames)}

Each message will start with 'From name:', e.g:
From admin:
From {agentNames.First()}:
//your message//.");

var conv = this.ProcessConversationsForRolePlay(this.initializeMessages, conversationHistory);
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Message/ToolCallMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public ToolCallMessage(IEnumerable<ToolCall> toolCalls, string? from = null)
public ToolCallMessage(string functionName, string functionArgs, string? from = null)
{
this.From = from;
this.ToolCalls = new List<ToolCall> { new ToolCall(functionName, functionArgs) };
this.ToolCalls = new List<ToolCall> { new ToolCall(functionName, functionArgs) { ToolCallId = functionName } };
}

public ToolCallMessage(ToolCallMessageUpdate update)
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/AutoGen.Core/Message/ToolCallResultMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public ToolCallResultMessage(IEnumerable<ToolCall> toolCalls, string? from = nul
public ToolCallResultMessage(string result, string functionName, string functionArgs, string? from = null)
{
this.From = from;
var toolCall = new ToolCall(functionName, functionArgs);
var toolCall = new ToolCall(functionName, functionArgs) { ToolCallId = functionName };
toolCall.Result = result;
this.ToolCalls = [toolCall];
}
Expand Down
9 changes: 8 additions & 1 deletion dotnet/src/AutoGen.Mistral/DTOs/ChatMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,25 @@ public enum RoleEnum

[JsonPropertyName("tool_calls")]
public List<FunctionContent>? ToolCalls { get; set; }

[JsonPropertyName("tool_call_id")]
public string? ToolCallId { get; set; }
}

public class FunctionContent
{
public FunctionContent(FunctionCall function)
public FunctionContent(string id, FunctionCall function)
{
this.Function = function;
this.Id = id;
}

[JsonPropertyName("function")]
public FunctionCall Function { get; set; }

[JsonPropertyName("id")]
public string Id { get; set; }

public class FunctionCall
{
public FunctionCall(string name, string arguments)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from
else if (finishReason == Choice.FinishReasonEnum.ToolCalls)
{
var functionContents = choice.Message?.ToolCalls ?? throw new ArgumentNullException("choice.Message.ToolCalls");
var toolCalls = functionContents.Select(f => new ToolCall(f.Function.Name, f.Function.Arguments)).ToList();
var toolCalls = functionContents.Select(f => new ToolCall(f.Function.Name, f.Function.Arguments) { ToolCallId = f.Id }).ToList();
return new ToolCallMessage(toolCalls, from: from.Name);
}
else
Expand Down Expand Up @@ -257,6 +257,7 @@ private IEnumerable<IMessage<ChatMessage>> ProcessToolCallResultMessage(ToolCall
var message = new ChatMessage(ChatMessage.RoleEnum.Tool, content: toolCall.Result)
{
Name = toolCall.FunctionName,
ToolCallId = toolCall.ToolCallId,
};

messages.Add(message);
Expand Down Expand Up @@ -305,10 +306,12 @@ private IEnumerable<IMessage<ChatMessage>> ProcessToolCallMessage(ToolCallMessag
// convert tool call message to chat message
var chatMessage = new ChatMessage(ChatMessage.RoleEnum.Assistant);
chatMessage.ToolCalls = new List<FunctionContent>();
foreach (var toolCall in toolCallMessage.ToolCalls)
for (var i = 0; i < toolCallMessage.ToolCalls.Count; i++)
{
var toolCall = toolCallMessage.ToolCalls[i];
var toolCallId = toolCall.ToolCallId ?? $"{toolCall.FunctionName}_{i}";
var functionCall = new FunctionContent.FunctionCall(toolCall.FunctionName, toolCall.FunctionArguments);
var functionContent = new FunctionContent(functionCall);
var functionContent = new FunctionContent(toolCallId, functionCall);
chatMessage.ToolCalls.Add(functionContent);
}

Expand Down
12 changes: 8 additions & 4 deletions dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,15 @@ public async Task MistralAgentFunctionCallMessageTest()
}
""";
var functionCallResult = await this.GetWeatherWrapper(weatherFunctionArgumets);

var toolCall = new ToolCall(this.GetWeatherFunctionContract.Name!, weatherFunctionArgumets)
{
ToolCallId = "012345678", // Mistral AI requires the tool call id to be a length of 9
Result = functionCallResult,
};
IMessage[] chatHistory = [
new TextMessage(Role.User, "what's the weather in Seattle?"),
new ToolCallMessage(this.GetWeatherFunctionContract.Name!, weatherFunctionArgumets, from: agent.Name),
new ToolCallResultMessage(functionCallResult, this.GetWeatherFunctionContract.Name!, weatherFunctionArgumets),
new ToolCallMessage([toolCall], from: agent.Name),
new ToolCallResultMessage([toolCall], weatherFunctionArgumets),
];

var reply = await agent.SendAsync(chatHistory: chatHistory);
Expand Down Expand Up @@ -152,7 +156,7 @@ public async Task MistralAgentFunctionCallMiddlewareMessageTest()

var question = new TextMessage(Role.User, "what's the weather in Seattle?");
var reply = await functionCallAgent.SendAsync(question);
reply.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
reply.Should().BeOfType<ToolCallAggregateMessage>();

// resend the reply to the same agent so it can generate the final response
// because the reply's from is the agent's name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public async Task ItRegisterKernelFunctionMiddlewareFromTestPluginTests()

var reply = await agent.SendAsync("what's the status of the light?");
reply.GetContent().Should().Be("off");
reply.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
if (reply is AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
reply.Should().BeOfType<ToolCallAggregateMessage>();
if (reply is ToolCallAggregateMessage aggregateMessage)
{
var toolCallMessage = aggregateMessage.Message1;
toolCallMessage.ToolCalls.Should().HaveCount(1);
Expand All @@ -44,8 +44,8 @@ public async Task ItRegisterKernelFunctionMiddlewareFromTestPluginTests()

reply = await agent.SendAsync("change the status of the light to on");
reply.GetContent().Should().Be("The status of the light is now on");
reply.Should().BeOfType<AggregateMessage<ToolCallMessage, ToolCallResultMessage>>();
if (reply is AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage1)
reply.Should().BeOfType<ToolCallAggregateMessage>();
if (reply is ToolCallAggregateMessage aggregateMessage1)
{
var toolCallMessage = aggregateMessage1.Message1;
toolCallMessage.ToolCalls.Should().HaveCount(1);
Expand Down