Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Added additional data to auto function invocation filter context #7398

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// history: if they don't want it, they can remove it, but this makes the data available,
// including metadata like usage.
chatRequest.AddMessage(chatChoice.Message!);
chatHistory.Add(this.ToChatMessageContent(modelId, responseData, chatChoice));

var chatMessageContent = this.ToChatMessageContent(modelId, responseData, chatChoice);
chatHistory.Add(chatMessageContent);

// We must send back a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
Expand Down Expand Up @@ -172,8 +174,9 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory, chatMessageContent)
dmytrostruk marked this conversation as resolved.
Show resolved Hide resolved
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
Expand Down Expand Up @@ -404,8 +407,9 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory, chatHistory.Last())
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
Expand Down
10 changes: 7 additions & 3 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,9 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chat)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chat, result)
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
Expand Down Expand Up @@ -760,7 +761,9 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// Add the original assistant message to the chatOptions; this is required for the service
// to understand the tool call responses.
chatOptions.Messages.Add(GetRequestMessage(streamedRole ?? default, content, streamedName, toolCalls));
chat.Add(this.GetChatMessage(streamedRole ?? default, content, toolCalls, functionCallContents, metadata, streamedName));

var chatMessageContent = this.GetChatMessage(streamedRole ?? default, content, toolCalls, functionCallContents, metadata, streamedName);
chat.Add(chatMessageContent);

// Respond to each tooling request.
for (int toolCallIndex = 0; toolCallIndex < toolCalls.Length; toolCallIndex++)
Expand Down Expand Up @@ -805,8 +808,9 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chat)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chat, chatMessageContent)
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ public async Task FiltersAreExecutedCorrectlyOnStreamingAsync()
public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync()
{
// Arrange
var function = KernelFunctionFactory.CreateFromMethod(() => "Result");
var executionOrder = new List<string>();

var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1");
Expand Down Expand Up @@ -183,7 +182,6 @@ public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync()
public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming)
{
// Arrange
var function = KernelFunctionFactory.CreateFromMethod(() => "Result");
var executionOrder = new List<string>();

var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1");
Expand Down Expand Up @@ -617,6 +615,84 @@ public async Task FilterContextHasCancellationTokenAsync()
Assert.Equal(0, secondFunctionInvocations);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task FilterContextHasOperationRelatedInformationAsync(bool isStreaming)
{
// Arrange
List<string?> actualToolCallIds = [];
List<ChatMessageContent> actualChatMessageContents = [];

var function = KernelFunctionFactory.CreateFromMethod(() => "Result");

var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1");
var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2");

var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);

var filter = new AutoFunctionInvocationFilter(async (context, next) =>
{
actualToolCallIds.Add(context.ToolCallId);
actualChatMessageContents.Add(context.ChatMessageContent);

await next(context);
});

var builder = Kernel.CreateBuilder();

builder.Plugins.Add(plugin);

builder.AddOpenAIChatCompletion(
modelId: "test-model-id",
apiKey: "test-api-key",
httpClient: this._httpClient);

builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(filter);

var kernel = builder.Build();

var arguments = new KernelArguments(new OpenAIPromptExecutionSettings
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
});

// Act
if (isStreaming)
{
using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("filters_streaming_multiple_function_calls_test_response.txt")) };
using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_streaming_test_response.txt")) };

this._messageHandlerStub.ResponsesToReturn = [response1, response2];

await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", arguments))
{ }
}
else
{
using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("filters_multiple_function_calls_test_response.json")) };
using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) };

this._messageHandlerStub.ResponsesToReturn = [response1, response2];

await kernel.InvokePromptAsync("Test prompt", arguments);
}

// Assert
Assert.Equal(["tool-call-id-1", "tool-call-id-2"], actualToolCallIds);

foreach (var chatMessageContent in actualChatMessageContents)
{
var content = chatMessageContent as OpenAIChatMessageContent;

Assert.NotNull(content);

Assert.Equal("test-model-id", content.ModelId);
Assert.Equal(AuthorRole.Assistant, content.Role);
Assert.Equal(2, content.ToolCalls.Count);
}
}

public void Dispose()
{
this._httpClient.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
"content": null,
"tool_calls": [
{
"id": "1",
"id": "tool-call-id-1",
"type": "function",
"function": {
"name": "MyPlugin-Function1",
"arguments": "{\n\"parameter\": \"function1-value\"\n}"
}
},
{
"id": "2",
"id": "tool-call-id-2",
"type": "function",
"function": {
"name": "MyPlugin-Function2",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin-Function1","arguments":"{\n\"parameter\": \"function1-value\"\n}"}}]},"finish_reason":"tool_calls"}]}
data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"tool-call-id-1","type":"function","function":{"name":"MyPlugin-Function1","arguments":"{\n\"parameter\": \"function1-value\"\n}"}}]},"finish_reason":"tool_calls"}]}

data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin-Function2","arguments":"{\n\"parameter\": \"function2-value\"\n}"}}]},"finish_reason":"tool_calls"}]}
data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"tool-call-id-2","type":"function","function":{"name":"MyPlugin-Function2","arguments":"{\n\"parameter\": \"function2-value\"\n}"}}]},"finish_reason":"tool_calls"}]}

data: [DONE]
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
<Right>lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.AutoFunctionInvocationContext.#ctor(Microsoft.SemanticKernel.Kernel,Microsoft.SemanticKernel.KernelFunction,Microsoft.SemanticKernel.FunctionResult,Microsoft.SemanticKernel.ChatCompletion.ChatHistory)</Target>
<Left>lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll</Left>
<Right>lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.BinaryContent.#ctor(System.Func{System.Threading.Tasks.Task{System.IO.Stream}},System.String,System.Object,System.Collections.Generic.IReadOnlyDictionary{System.String,System.Object})</Target>
Expand Down Expand Up @@ -71,6 +78,13 @@
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.AutoFunctionInvocationContext.#ctor(Microsoft.SemanticKernel.Kernel,Microsoft.SemanticKernel.KernelFunction,Microsoft.SemanticKernel.FunctionResult,Microsoft.SemanticKernel.ChatCompletion.ChatHistory)</Target>
<Left>lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll</Left>
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.BinaryContent.#ctor(System.Func{System.Threading.Tasks.Task{System.IO.Stream}},System.String,System.Object,System.Collections.Generic.IReadOnlyDictionary{System.String,System.Object})</Target>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,25 @@ public class AutoFunctionInvocationContext
/// <param name="function">The <see cref="KernelFunction"/> with which this filter is associated.</param>
/// <param name="result">The result of the function's invocation.</param>
/// <param name="chatHistory">The chat history associated with automatic function invocation.</param>
/// <param name="chatMessageContent">The chat message content associated with automatic function invocation.</param>
public AutoFunctionInvocationContext(
markwallace-microsoft marked this conversation as resolved.
Show resolved Hide resolved
Kernel kernel,
KernelFunction function,
FunctionResult result,
ChatHistory chatHistory)
ChatHistory chatHistory,
ChatMessageContent chatMessageContent)
{
Verify.NotNull(kernel);
Verify.NotNull(function);
Verify.NotNull(result);
Verify.NotNull(chatHistory);
Verify.NotNull(chatMessageContent);

this.Kernel = kernel;
this.Function = function;
this.Result = result;
this.ChatHistory = chatHistory;
this.ChatMessageContent = chatMessageContent;
}

/// <summary>
Expand Down Expand Up @@ -62,6 +66,16 @@ public AutoFunctionInvocationContext(
/// </summary>
public int FunctionCount { get; init; }

/// <summary>
/// The ID of the tool call.
/// </summary>
public string? ToolCallId { get; init; }
dmytrostruk marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// The chat message content associated with automatic function invocation.
/// </summary>
public ChatMessageContent ChatMessageContent { get; }

/// <summary>
/// Gets the <see cref="Microsoft.SemanticKernel.ChatCompletion.ChatHistory"/> associated with automatic function invocation.
/// </summary>
Expand Down
Loading