.Net: Added additional data to auto function invocation filter context (microsoft#7398)

### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

Resolves: microsoft#7208

Added the following data to the auto function invocation filter context (a usage sketch follows the list):
- Tool Call ID
- `ChatMessageContent` associated with function invocation operation
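
For illustration, a minimal sketch of a filter consuming these new properties. It assumes the existing `IAutoFunctionInvocationFilter` interface; the filter class name and the logging are illustrative only and not part of this change.

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;

// Hypothetical filter that inspects the data added by this PR before
// letting the function invocation proceed.
public sealed class ToolCallInspectionFilter : IAutoFunctionInvocationFilter
{
    public async Task OnAutoFunctionInvocationAsync(
        AutoFunctionInvocationContext context,
        Func<AutoFunctionInvocationContext, Task> next)
    {
        // New in this PR: the tool call ID and the chat message content
        // associated with the function invocation.
        string? toolCallId = context.ToolCallId;
        ChatMessageContent message = context.ChatMessageContent;

        Console.WriteLine(
            $"Invoking '{context.Function.Name}' for tool call '{toolCallId}' " +
            $"requested by a '{message.Role}' message.");

        await next(context);
    }
}
```

Such a filter is registered the same way as before, for example via `builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(...)`, as the unit tests below do.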

Note: An API compatibility suppression is required for the
`AutoFunctionInvocationContext` constructor in order to make `ChatMessageContent`
a required parameter; auto function invocation happens only during a chat
completion operation, so `ChatMessageContent` should always be available.
The constructor is public at the moment because it lives in the
`Microsoft.SemanticKernel.Abstractions` package, while it is initialized by the
auto function calling logic in a specific connector. Once the function calling
abstraction is in place, the constructor will be marked `internal`, since
`AutoFunctionInvocationContext` should not be created on the user side.
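
As a hedged sketch of the connector-side construction this note describes (mirroring the diff below); in practice only connectors create this context, and all values here are placeholders:

```csharp
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

// Placeholder setup; a real connector derives these from the chat completion response.
var kernel = Kernel.CreateBuilder().Build();
KernelFunction function = KernelFunctionFactory.CreateFromMethod(() => "Result", "MyFunction");
var functionResult = new FunctionResult(function) { Culture = kernel.Culture };

var chatHistory = new ChatHistory();
var assistantMessage = new ChatMessageContent(AuthorRole.Assistant, "Assistant message containing tool calls");
chatHistory.Add(assistantMessage);

// The ChatMessageContent argument is now required; ToolCallId is an init-only property.
var invocationContext = new AutoFunctionInvocationContext(
    kernel, function, functionResult, chatHistory, assistantMessage)
{
    ToolCallId = "tool-call-id-1",
};
```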

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
dmytrostruk authored Jul 23, 2024
1 parent 1652b9f commit 190b69b
Showing 7 changed files with 125 additions and 13 deletions.
@@ -138,7 +138,9 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// history: if they don't want it, they can remove it, but this makes the data available,
// including metadata like usage.
chatRequest.AddMessage(chatChoice.Message!);
chatHistory.Add(this.ToChatMessageContent(modelId, responseData, chatChoice));

var chatMessageContent = this.ToChatMessageContent(modelId, responseData, chatChoice);
chatHistory.Add(chatMessageContent);

// We must send back a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
@@ -172,8 +174,9 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory, chatMessageContent)
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
@@ -404,8 +407,9 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory, chatHistory.Last())
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
10 changes: 7 additions & 3 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
@@ -506,8 +509,9 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chat)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chat, result)
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
@@ -760,7 +761,9 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// Add the original assistant message to the chatOptions; this is required for the service
// to understand the tool call responses.
chatOptions.Messages.Add(GetRequestMessage(streamedRole ?? default, content, streamedName, toolCalls));
chat.Add(this.GetChatMessage(streamedRole ?? default, content, toolCalls, functionCallContents, metadata, streamedName));

var chatMessageContent = this.GetChatMessage(streamedRole ?? default, content, toolCalls, functionCallContents, metadata, streamedName);
chat.Add(chatMessageContent);

// Respond to each tooling request.
for (int toolCallIndex = 0; toolCallIndex < toolCalls.Length; toolCallIndex++)
@@ -805,8 +808,9 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chat)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chat, chatMessageContent)
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
@@ -126,7 +126,6 @@ public async Task FiltersAreExecutedCorrectlyOnStreamingAsync()
public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync()
{
// Arrange
var function = KernelFunctionFactory.CreateFromMethod(() => "Result");
var executionOrder = new List<string>();

var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1");
@@ -183,7 +182,6 @@ public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync()
public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming)
{
// Arrange
var function = KernelFunctionFactory.CreateFromMethod(() => "Result");
var executionOrder = new List<string>();

var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1");
@@ -617,6 +615,84 @@ public async Task FilterContextHasCancellationTokenAsync()
Assert.Equal(0, secondFunctionInvocations);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task FilterContextHasOperationRelatedInformationAsync(bool isStreaming)
{
// Arrange
List<string?> actualToolCallIds = [];
List<ChatMessageContent> actualChatMessageContents = [];

var function = KernelFunctionFactory.CreateFromMethod(() => "Result");

var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1");
var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2");

var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);

var filter = new AutoFunctionInvocationFilter(async (context, next) =>
{
actualToolCallIds.Add(context.ToolCallId);
actualChatMessageContents.Add(context.ChatMessageContent);

await next(context);
});

var builder = Kernel.CreateBuilder();

builder.Plugins.Add(plugin);

builder.AddOpenAIChatCompletion(
modelId: "test-model-id",
apiKey: "test-api-key",
httpClient: this._httpClient);

builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(filter);

var kernel = builder.Build();

var arguments = new KernelArguments(new OpenAIPromptExecutionSettings
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
});

// Act
if (isStreaming)
{
using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("filters_streaming_multiple_function_calls_test_response.txt")) };
using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_streaming_test_response.txt")) };

this._messageHandlerStub.ResponsesToReturn = [response1, response2];

await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", arguments))
{ }
}
else
{
using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("filters_multiple_function_calls_test_response.json")) };
using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) };

this._messageHandlerStub.ResponsesToReturn = [response1, response2];

await kernel.InvokePromptAsync("Test prompt", arguments);
}

// Assert
Assert.Equal(["tool-call-id-1", "tool-call-id-2"], actualToolCallIds);

foreach (var chatMessageContent in actualChatMessageContents)
{
var content = chatMessageContent as OpenAIChatMessageContent;

Assert.NotNull(content);

Assert.Equal("test-model-id", content.ModelId);
Assert.Equal(AuthorRole.Assistant, content.Role);
Assert.Equal(2, content.ToolCalls.Count);
}
}

public void Dispose()
{
this._httpClient.Dispose();
@@ -11,15 +11,15 @@
"content": null,
"tool_calls": [
{
"id": "1",
"id": "tool-call-id-1",
"type": "function",
"function": {
"name": "MyPlugin-Function1",
"arguments": "{\n\"parameter\": \"function1-value\"\n}"
}
},
{
"id": "2",
"id": "tool-call-id-2",
"type": "function",
"function": {
"name": "MyPlugin-Function2",
@@ -1,5 +1,5 @@
data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin-Function1","arguments":"{\n\"parameter\": \"function1-value\"\n}"}}]},"finish_reason":"tool_calls"}]}
data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"tool-call-id-1","type":"function","function":{"name":"MyPlugin-Function1","arguments":"{\n\"parameter\": \"function1-value\"\n}"}}]},"finish_reason":"tool_calls"}]}

data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin-Function2","arguments":"{\n\"parameter\": \"function2-value\"\n}"}}]},"finish_reason":"tool_calls"}]}
data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"tool-call-id-2","type":"function","function":{"name":"MyPlugin-Function2","arguments":"{\n\"parameter\": \"function2-value\"\n}"}}]},"finish_reason":"tool_calls"}]}

data: [DONE]
@@ -8,6 +8,13 @@
<Right>lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.AutoFunctionInvocationContext.#ctor(Microsoft.SemanticKernel.Kernel,Microsoft.SemanticKernel.KernelFunction,Microsoft.SemanticKernel.FunctionResult,Microsoft.SemanticKernel.ChatCompletion.ChatHistory)</Target>
<Left>lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll</Left>
<Right>lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.BinaryContent.#ctor(System.Func{System.Threading.Tasks.Task{System.IO.Stream}},System.String,System.Object,System.Collections.Generic.IReadOnlyDictionary{System.String,System.Object})</Target>
@@ -71,6 +78,13 @@
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.AutoFunctionInvocationContext.#ctor(Microsoft.SemanticKernel.Kernel,Microsoft.SemanticKernel.KernelFunction,Microsoft.SemanticKernel.FunctionResult,Microsoft.SemanticKernel.ChatCompletion.ChatHistory)</Target>
<Left>lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll</Left>
<Right>lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll</Right>
<IsBaselineSuppression>true</IsBaselineSuppression>
</Suppression>
<Suppression>
<DiagnosticId>CP0002</DiagnosticId>
<Target>M:Microsoft.SemanticKernel.BinaryContent.#ctor(System.Func{System.Threading.Tasks.Task{System.IO.Stream}},System.String,System.Object,System.Collections.Generic.IReadOnlyDictionary{System.String,System.Object})</Target>
@@ -19,21 +19,25 @@ public class AutoFunctionInvocationContext
/// <param name="function">The <see cref="KernelFunction"/> with which this filter is associated.</param>
/// <param name="result">The result of the function's invocation.</param>
/// <param name="chatHistory">The chat history associated with automatic function invocation.</param>
/// <param name="chatMessageContent">The chat message content associated with automatic function invocation.</param>
public AutoFunctionInvocationContext(
Kernel kernel,
KernelFunction function,
FunctionResult result,
ChatHistory chatHistory)
ChatHistory chatHistory,
ChatMessageContent chatMessageContent)
{
Verify.NotNull(kernel);
Verify.NotNull(function);
Verify.NotNull(result);
Verify.NotNull(chatHistory);
Verify.NotNull(chatMessageContent);

this.Kernel = kernel;
this.Function = function;
this.Result = result;
this.ChatHistory = chatHistory;
this.ChatMessageContent = chatMessageContent;
}

/// <summary>
@@ -62,6 +66,16 @@ public AutoFunctionInvocationContext(
/// </summary>
public int FunctionCount { get; init; }

/// <summary>
/// The ID of the tool call.
/// </summary>
public string? ToolCallId { get; init; }

/// <summary>
/// The chat message content associated with automatic function invocation.
/// </summary>
public ChatMessageContent ChatMessageContent { get; }

/// <summary>
/// Gets the <see cref="Microsoft.SemanticKernel.ChatCompletion.ChatHistory"/> associated with automatic function invocation.
/// </summary>