Skip to content

Commit

Permalink
.Net BugFix - Using StandardizedPrompt With Kernel/Function InvokeAsy…
Browse files Browse the repository at this point in the history
…nc with ChatCompletions was not being parsed correctly. (#4025)

### Motivation and Context

Standardized prompts used with Kernel/Function InvokeAsync and
ChatCompletions were not being parsed correctly.

Small bug fix: renamed the parameter `chat` to `chatHistory` in the
IChatCompletion interfaces.
Resolves #3960
  • Loading branch information
RogerBarreto authored Dec 6, 2023
1 parent 5f047e5 commit 136e7c2
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ public AzureOpenAIChatCompletionService(
public IReadOnlyDictionary<string, object?> Attributes => this._core.Attributes;

/// <inheritdoc/>
public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chat, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._core.GetChatMessageContentsAsync(chat, executionSettings, kernel, cancellationToken);
public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._core.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);

/// <inheritdoc/>
public IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ public OpenAIChatCompletionService(
public IReadOnlyDictionary<string, object?> Attributes => this.Attributes;

/// <inheritdoc/>
public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chat, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._core.GetChatMessageContentsAsync(chat, executionSettings, kernel, cancellationToken);
public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._core.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);

/// <inheritdoc/>
public IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ public AzureOpenAIChatCompletionWithDataService(
public IReadOnlyDictionary<string, object?> Attributes => this._attributes;

/// <inheritdoc/>
public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chat, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this.InternalGetChatMessageContentsAsync(chat, executionSettings, kernel, cancellationToken);
public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this.InternalGetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);

/// <inheritdoc/>
public IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,6 @@ public async Task MultipleServiceLoadPromptConfigTestAsync()
// Assert
Assert.Contains("Pike Place", azureResult.GetValue<string>(), StringComparison.OrdinalIgnoreCase);
}

#region internals

private readonly XunitLogger<Kernel> _logger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ public interface IChatCompletionService : IAIService
/// <remarks>
/// This should be used when the settings request for more than one choice.
/// </remarks>
/// <param name="chat">The chat history context.</param>
/// <param name="chatHistory">The chat history context.</param>
/// <param name="executionSettings">The AI execution settings (optional).</param>
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of different chat results generated by the remote model</returns>
Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(
ChatHistory chat,
ChatHistory chatHistory,
PromptExecutionSettings? executionSettings = null,
Kernel? kernel = null,
CancellationToken cancellationToken = default);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class StreamingChatMessageContent : StreamingContentBase
/// <param name="encoding">Encoding of the chat</param>
/// <param name="metadata">Additional metadata</param>
[JsonConstructor]
protected StreamingChatMessageContent(AuthorRole? role, string? content, object? innerContent, int choiceIndex = 0, string? modelId = null, Encoding? encoding = null, IDictionary<string, object?>? metadata = null) : base(innerContent, choiceIndex, modelId, metadata)
public StreamingChatMessageContent(AuthorRole? role, string? content, object? innerContent = null, int choiceIndex = 0, string? modelId = null, Encoding? encoding = null, IDictionary<string, object?>? metadata = null) : base(innerContent, choiceIndex, modelId, metadata)
{
this.Role = role;
this.Content = content;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.AI.ChatCompletion;

namespace Microsoft.SemanticKernel.AI.TextGeneration;

Expand All @@ -12,7 +16,7 @@ namespace Microsoft.SemanticKernel.AI.TextGeneration;
public static class TextGenerationExtensions
{
/// <summary>
/// Get a single text completion result for the prompt and settings.
/// Get a single text generation result for the prompt and settings.
/// </summary>
/// <param name="textGenerationService">Text generation service</param>
/// <param name="prompt">The standardized prompt input.</param>
Expand All @@ -28,4 +32,72 @@ public static async Task<TextContent> GetTextContentAsync(
CancellationToken cancellationToken = default)
=> (await textGenerationService.GetTextContentsAsync(prompt, executionSettings, kernel, cancellationToken).ConfigureAwait(false))
.Single();

/// <summary>
/// Get a single text generation result for the standardized prompt and settings.
/// </summary>
/// <param name="textGenerationService">Text generation service</param>
/// <param name="prompt">The standardized prompt input.</param>
/// <param name="executionSettings">The AI execution settings (optional).</param>
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>List of different text results generated by the remote model</returns>
internal static async Task<TextContent> GetTextContentWithDefaultParserAsync(
this ITextGenerationService textGenerationService,
string prompt,
PromptExecutionSettings? executionSettings = null,
Kernel? kernel = null,
CancellationToken cancellationToken = default)
{
if (textGenerationService is IChatCompletionService chatCompletion
&& XmlPromptParser.TryParse(prompt!, out var nodes)
&& ChatPromptParser.TryParse(nodes, out var chatHistory))
{
var chatMessage = await chatCompletion.GetChatMessageContentAsync(chatHistory, executionSettings, kernel, cancellationToken).ConfigureAwait(false);
return new TextContent(chatMessage.Content, chatMessage.ModelId, chatMessage.InnerContent, chatMessage.Encoding, chatMessage.Metadata);
}

// When using against text generations, the prompt will be used as is.
return await textGenerationService.GetTextContentAsync(prompt, executionSettings, kernel, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Get streaming results for the standardized prompt using the specified settings.
/// Each modality may support different types of streaming contents.
/// </summary>
/// <remarks>
/// Usage of this method with value types may be more efficient if the connector supports it.
/// When the service also implements <see cref="IChatCompletionService"/> and the prompt parses as a
/// chat-formatted XML prompt, the chat streaming modality is used and each chunk is converted into a
/// <see cref="StreamingTextContent"/>; otherwise the prompt is streamed through the text-generation service as-is.
/// </remarks>
/// <exception cref="NotSupportedException">Throws if the specified type is not the same or fail to cast</exception>
/// <param name="textGenerationService">Text generation service</param>
/// <param name="prompt">The standardized prompt to complete.</param>
/// <param name="executionSettings">The AI execution settings (optional).</param>
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Streaming list of different generation streaming string updates generated by the remote model</returns>
internal static async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsWithDefaultParserAsync(
    this ITextGenerationService textGenerationService,
    string prompt,
    PromptExecutionSettings? executionSettings = null,
    Kernel? kernel = null,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Prefer the chat modality when the service supports it and the prompt is a valid chat prompt.
    // Note: `prompt` is a non-nullable string, so no null-forgiving operator is needed.
    if (textGenerationService is IChatCompletionService chatCompletion
        && XmlPromptParser.TryParse(prompt, out var nodes)
        && ChatPromptParser.TryParse(nodes, out var chatHistory))
    {
        // ConfigureAwait(false) for library code, consistent with the non-streaming path above (CA2007).
        await foreach (var chatMessage in chatCompletion.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken).ConfigureAwait(false))
        {
            yield return new StreamingTextContent(chatMessage.Content, chatMessage.ChoiceIndex, chatMessage.ModelId, chatMessage, chatMessage.Encoding, chatMessage.Metadata);
        }

        yield break;
    }

    // When using against text generations, the prompt will be used as is.
    await foreach (var textChunk in textGenerationService.GetStreamingTextContentsAsync(prompt, executionSettings, kernel, cancellationToken).ConfigureAwait(false))
    {
        yield return textChunk;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ protected override async ValueTask<FunctionResult> InvokeCoreAsync(
throw new OperationCanceledException($"A {nameof(Kernel)}.{nameof(Kernel.PromptRendered)} event handler requested cancellation before function invocation.");
}

var textContent = await textGeneration.GetTextContentAsync(renderedPrompt, arguments.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false);
var textContent = await textGeneration.GetTextContentWithDefaultParserAsync(renderedPrompt, arguments.ExecutionSettings, kernel, cancellationToken).ConfigureAwait(false);

return new FunctionResult(this, textContent.Text, kernel.Culture, textContent.Metadata);
}
Expand All @@ -146,7 +146,7 @@ protected override async IAsyncEnumerable<T> InvokeCoreStreamingAsync<T>(
yield break;
}

await foreach (var content in textGeneration.GetStreamingTextContentsAsync(renderedPrompt, arguments.ExecutionSettings, kernel, cancellationToken))
await foreach (var content in textGeneration.GetStreamingTextContentsWithDefaultParserAsync(renderedPrompt, arguments.ExecutionSettings, kernel, cancellationToken))
{
cancellationToken.ThrowIfCancellationRequested();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using Microsoft.SemanticKernel.AI.TextGeneration;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI;
using Moq;
Expand Down Expand Up @@ -120,4 +123,145 @@ public async Task ItFailsIfInvalidServiceIdIsProvidedAsync()
// Assert
Assert.Equal("Required service of type Microsoft.SemanticKernel.AI.TextGeneration.ITextGenerationService not registered. Expected serviceIds: service3.", exception.Message);
}

[Fact]
public async Task ItParsesStandardizedPromptWhenServiceIsChatCompletionAsync()
{
// Arrange: a fake that implements both text generation and chat completion,
// registered as the kernel's ITextGenerationService.
var service = new FakeChatAsTextService();
var kernel = new KernelBuilder()
.WithServices(sc => { sc.AddTransient<ITextGenerationService>((sp) => service); })
.Build();

KernelFunction target = KernelFunctionFactory.CreateFromPrompt(@"
<message role=""system"">You are a helpful assistant.</message>
<message role=""user"">How many 20 cents can I get from 1 dollar?</message>
");

// Act
await kernel.InvokeAsync(target);

// Assert: the XML prompt was parsed into a two-message chat history and routed to the chat modality.
var history = service.ChatHistory;
Assert.NotNull(history);
Assert.Equal(2, history!.Count);
Assert.Equal("You are a helpful assistant.", history[0].Content);
Assert.Equal("How many 20 cents can I get from 1 dollar?", history[1].Content);
}

[Fact]
public async Task ItParsesStandardizedPromptWhenServiceIsStreamingChatCompletionAsync()
{
// Arrange: a fake that implements both text generation and chat completion,
// registered as the kernel's ITextGenerationService.
var service = new FakeChatAsTextService();
var kernel = new KernelBuilder()
.WithServices(sc => { sc.AddTransient<ITextGenerationService>((sp) => service); })
.Build();

KernelFunction target = KernelFunctionFactory.CreateFromPrompt(@"
<message role=""system"">You are a helpful assistant.</message>
<message role=""user"">How many 20 cents can I get from 1 dollar?</message>
");

// Act: drain the stream; the chunks themselves are not under test here.
await foreach (var _ in kernel.InvokeStreamingAsync(target))
{
}

// Assert: the XML prompt was parsed into a two-message chat history and routed to the streaming chat modality.
var history = service.ChatHistory;
Assert.NotNull(history);
Assert.Equal(2, history!.Count);
Assert.Equal("You are a helpful assistant.", history[0].Content);
Assert.Equal("How many 20 cents can I get from 1 dollar?", history[1].Content);
}

[Fact]
public async Task ItNotParsesStandardizedPromptWhenServiceIsOnlyTextCompletionAsync()
{
// Arrange: a service that implements ONLY ITextGenerationService (no chat),
// so the standardized-prompt parser must NOT kick in.
var mockService = new Mock<ITextGenerationService>();
var mockResult = mockService.Setup(s => s.GetTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new List<TextContent>() { new("something") });

var kernel = new KernelBuilder().WithServices(sc => { sc.AddTransient<ITextGenerationService>((sp) => mockService.Object); }).Build();

var inputPrompt = @"
<message role=""system"">You are a helpful assistant.</message>
<message role=""user"">How many 20 cents can I get from 1 dollar?</message>
";

KernelFunction function = KernelFunctionFactory.CreateFromPrompt(inputPrompt);

// Act + Assert
// The assertions run inside the Moq callback when the service is invoked:
// the prompt must reach the text-generation service verbatim (unparsed).
mockResult.Callback((string prompt, PromptExecutionSettings _, Kernel _, CancellationToken _) =>
{
Assert.NotNull(prompt);
Assert.Equal(inputPrompt, prompt);
});

await kernel.InvokeAsync(function);
}

[Fact]
public async Task ItNotParsesStandardizedPromptWhenStreamingInServiceIsOnlyTextCompletionAsync()
{
// Arrange: a streaming service that implements ONLY ITextGenerationService (no chat),
// so the standardized-prompt parser must NOT kick in.
var mockService = new Mock<ITextGenerationService>();
var mockResult = mockService.Setup(s => s.GetStreamingTextContentsAsync(It.IsAny<string>(), It.IsAny<PromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()))
.Returns(this.ToAsyncEnumerable<StreamingTextContent>(new List<StreamingTextContent>() { new("something") }));

var kernel = new KernelBuilder().WithServices(sc => { sc.AddTransient<ITextGenerationService>((sp) => mockService.Object); }).Build();

var inputPrompt = @"
<message role=""system"">You are a helpful assistant.</message>
<message role=""user"">How many 20 cents can I get from 1 dollar?</message>
";

KernelFunction function = KernelFunctionFactory.CreateFromPrompt(inputPrompt);

// Act + Assert
// The assertions run inside the Moq callback when the service is invoked:
// the prompt must reach the streaming text-generation service verbatim (unparsed).
mockResult.Callback((string prompt, PromptExecutionSettings _, Kernel _, CancellationToken _) =>
{
Assert.NotNull(prompt);
Assert.Equal(inputPrompt, prompt);
});

// Drain the stream to force the service invocation; the chunks themselves are not inspected.
await foreach (var chunk in kernel.InvokeStreamingAsync(function))
{
}
}

// Wraps a synchronous enumerable as IAsyncEnumerable<T> so Moq can return it
// from streaming-API setups. CS1998 is suppressed because the method is async
// only to satisfy the iterator signature and never awaits.
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
#pragma warning disable IDE1006 // Naming Styles
private async IAsyncEnumerable<T> ToAsyncEnumerable<T>(IEnumerable<T> enumeration)
#pragma warning restore IDE1006 // Naming Styles
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
{
foreach (var enumerationItem in enumeration)
{
yield return enumerationItem;
}
}

// Test double that implements BOTH text generation and chat completion, and records
// the ChatHistory it receives so tests can assert the standardized prompt was parsed
// and routed to the chat modality.
private sealed class FakeChatAsTextService : ITextGenerationService, IChatCompletionService
{
// Not exercised by these tests.
public IReadOnlyDictionary<string, object?> Attributes => throw new NotImplementedException();
// Captures the chat history passed to the last chat-completion call (null if never called).
public ChatHistory? ChatHistory { get; private set; }

public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
// Record the history for assertion, then return a canned assistant reply.
this.ChatHistory = chatHistory;

return Task.FromResult<IReadOnlyList<ChatMessageContent>>(new List<ChatMessageContent> { new(AuthorRole.Assistant, "Something") });
}

#pragma warning disable IDE0036 // Order modifiers
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
#pragma warning restore IDE0036 // Order modifiers
{
// Record the history for assertion, then yield a single canned streaming chunk.
this.ChatHistory = chatHistory;
yield return new StreamingChatMessageContent(AuthorRole.Assistant, "Something");
}

// Text-generation members are intentionally unimplemented: the tests using this fake
// expect the chat path to be taken, so reaching these would be a failure.
public IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}

public Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
}
}

0 comments on commit 136e7c2

Please sign in to comment.