Skip to content

Commit

Permalink
Add knowledgebase and conversation APIs (#46893)
Browse files Browse the repository at this point in the history
  • Loading branch information
christothes authored Oct 29, 2024
1 parent 42e72d7 commit 3643e92
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 0 deletions.
1 change: 1 addition & 0 deletions eng/Packages.Data.props
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
<PackageReference Update="Azure.Provisioning.KeyVault" Version="1.0.0" />
<PackageReference Update="Azure.Provisioning.ServiceBus" Version="1.0.0" />
<PackageReference Update="Azure.Provisioning.Storage" Version="1.0.0" />
<PackageReference Update="Microsoft.Bcl.Numerics" Version="8.0.0" />

<!-- Other approved packages -->
<PackageReference Update="Microsoft.Azure.Amqp" Version="2.6.7" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,21 @@ namespace Azure.Provisioning.CloudMachine.OpenAI
{
public static partial class AzureOpenAIExtensions
{
public static Azure.Provisioning.CloudMachine.OpenAI.EmbeddingKnowledgebase CreateEmbeddingKnowledgebase(this Azure.Core.ClientWorkspace workspace) { throw null; }
public static Azure.Provisioning.CloudMachine.OpenAI.OpenAIConversation CreateOpenAIConversation(this Azure.Core.ClientWorkspace workspace) { throw null; }
public static OpenAI.Chat.ChatClient GetOpenAIChatClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
public static OpenAI.Embeddings.EmbeddingClient GetOpenAIEmbeddingsClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
}
public partial class EmbeddingKnowledgebase
{
internal EmbeddingKnowledgebase() { }
public void Add(string fact) { }
}
public partial class OpenAIConversation
{
internal OpenAIConversation() { }
public string Say(string message) { throw null; }
}
public partial class OpenAIFeature : Azure.Provisioning.CloudMachine.CloudMachineFeature
{
public OpenAIFeature(AiModel chatDeployment, AiModel? embeddingsDeployment = null) { }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<PackageReference Include="Azure.Provisioning.EventGrid" />
<PackageReference Include="Azure.Security.KeyVault.Secrets" />
<PackageReference Include="Microsoft.Extensions.Configuration.Abstractions" VersionOverride="8.0.0" />
<PackageReference Include="Microsoft.Bcl.Numerics" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using OpenAI.Embeddings;

namespace Azure.Provisioning.CloudMachine.OpenAI;

/// <summary>
/// Represents a knowledgebase of facts represented by embeddings that can be used to find relevant facts based on a given text.
/// </summary>
public class EmbeddingKnowledgebase
{
// Client used to turn text into embedding vectors.
private readonly EmbeddingClient _client;

// Chunks queued by Add that have not been sent to the embeddings client yet.
private readonly List<string> _factsToProcess = new();

// Parallel lists: _vectors[i] is the embedding of _facts[i].
private readonly List<ReadOnlyMemory<float>> _vectors = new();
private readonly List<string> _facts = new();

/// <summary>
/// Initializes a new instance of the <see cref="EmbeddingKnowledgebase"/> class.
/// </summary>
/// <param name="client">The client used to generate embeddings for facts and queries.</param>
internal EmbeddingKnowledgebase(EmbeddingClient client) => _client = client;

/// <summary>
/// Adds a fact to the knowledgebase. The text is split into chunks of roughly 1000 characters
/// and the chunks are embedded before this method returns.
/// </summary>
/// <param name="fact">The fact to add.</param>
/// <exception cref="ArgumentNullException"><paramref name="fact"/> is <c>null</c>.</exception>
public void Add(string fact)
{
    // Fail at the public boundary instead of with a NullReferenceException inside the chunker.
    if (fact is null)
    {
        throw new ArgumentNullException(nameof(fact));
    }

    ChunkAndAddToFactsToProcess(fact, 1000);
    ProcessUnprocessedFacts();
}

/// <summary>
/// Finds the stored facts whose embeddings are closest to the embedding of <paramref name="text"/>.
/// </summary>
/// <param name="text">The text to match against the knowledgebase.</param>
/// <param name="threshold">Maximum cosine distance (1 - cosine similarity) a fact may have to be returned.</param>
/// <param name="top">Maximum number of facts to return.</param>
/// <returns>
/// At most <paramref name="top"/> facts, ordered from closest to farthest. The list is empty when the
/// knowledgebase holds no facts or none is within <paramref name="threshold"/>.
/// </returns>
internal List<Fact> FindRelevantFacts(string text, float threshold = 0.29f, int top = 3)
{
    // Embed anything that is still queued so the search sees every fact.
    if (_factsToProcess.Count > 0)
    {
        ProcessUnprocessedFacts();
    }

    var results = new List<Fact>();
    if (_vectors.Count == 0)
    {
        // Nothing to compare against: skip the embedding request for the query text.
        return results;
    }

    ReadOnlySpan<float> textVector = ProcessFact(text).Span;

    var distances = new List<(float Distance, int Index)>(_vectors.Count);
    for (int index = 0; index < _vectors.Count; index++)
    {
        float distance = 1.0f - CosineSimilarity(_vectors[index].Span, textVector);
        distances.Add((distance, index));
    }

    // Closest (smallest distance) first.
    distances.Sort((a, b) => a.Distance.CompareTo(b.Distance));

    int limit = Math.Min(top, distances.Count);
    for (int i = 0; i < limit; i++)
    {
        (float distance, int index) = distances[i];
        if (distance > threshold)
        {
            // The list is sorted, so every remaining candidate is farther away still.
            break;
        }
        results.Add(new Fact(_facts[index], index));
    }
    return results;
}

/// <summary>
/// Computes the cosine similarity of two vectors of equal length.
/// </summary>
/// <param name="x">The first vector.</param>
/// <param name="y">The second vector.</param>
/// <returns>
/// The dot product divided by the product of the vector magnitudes, or 0 when that product
/// is not a positive number (for example when either vector is all zeros).
/// </returns>
/// <exception cref="ArgumentException">The vectors have different lengths.</exception>
private static float CosineSimilarity(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
{
    if (x.Length != y.Length)
    {
        throw new ArgumentException("Vectors must have the same length.", nameof(y));
    }

    float dot = 0, xSumSquared = 0, ySumSquared = 0;

    for (int i = 0; i < x.Length; i++)
    {
        dot += x[i] * y[i];
        xSumSquared += x[i] * x[i];
        ySumSquared += y[i] * y[i];
    }

    float magnitude = MathF.Sqrt(xSumSquared) * MathF.Sqrt(ySumSquared);

    // Dividing by a zero magnitude yields NaN; a NaN distance sorts ahead of every real distance
    // and is never greater than a threshold, so it would be reported as the best match.
    return magnitude > 0f ? dot / magnitude : 0f;
}

/// <summary>
/// Generates embeddings for every queued chunk, stores each vector next to its text, and empties the queue.
/// Does nothing when the queue is empty.
/// </summary>
private void ProcessUnprocessedFacts()
{
    if (_factsToProcess.Count != 0)
    {
        var response = _client.GenerateEmbeddings(_factsToProcess);

        foreach (var item in response.Value)
        {
            // item.Index identifies which queued chunk this embedding was generated for.
            string sourceText = _factsToProcess[item.Index];
            _vectors.Add(item.ToFloats());
            _facts.Add(sourceText);
        }

        _factsToProcess.Clear();
    }
}

/// <summary>
/// Generates the embedding vector for a single piece of text without storing it.
/// </summary>
/// <param name="fact">The text to embed.</param>
/// <returns>The embedding of <paramref name="fact"/> as floats.</returns>
private ReadOnlyMemory<float> ProcessFact(string fact)
    => _client.GenerateEmbedding(fact).Value.ToFloats();

/// <summary>
/// Splits <paramref name="text"/> into overlapping chunks and queues them for embedding.
/// </summary>
/// <remarks>
/// Each chunk is at least <paramref name="chunkSize"/> characters (except the last) and is extended to the
/// next whitespace character so that words are not cut. The next chunk starts about 85% of
/// <paramref name="chunkSize"/> after the current one, moved back to the nearest whitespace when there is one.
/// </remarks>
/// <param name="text">The text to split.</param>
/// <param name="chunkSize">The target chunk length in characters.</param>
/// <exception cref="ArgumentNullException"><paramref name="text"/> is <c>null</c>.</exception>
/// <exception cref="ArgumentException"><paramref name="chunkSize"/> is zero or negative.</exception>
internal void ChunkAndAddToFactsToProcess(string text, int chunkSize)
{
    if (text is null)
    {
        throw new ArgumentNullException(nameof(text));
    }
    if (chunkSize <= 0)
    {
        throw new ArgumentException("Chunk size must be greater than zero.", nameof(chunkSize));
    }

    int overlapSize = (int)(chunkSize * 0.15);
    int stepSize = chunkSize - overlapSize; // always >= 1 because overlapSize < chunkSize
    ReadOnlySpan<char> textSpan = text.AsSpan();

    int start = 0;
    while (start < textSpan.Length)
    {
        // The remainder fits in one chunk: emit it and stop. Continuing would only queue
        // fragments that are already fully contained in this chunk.
        if (textSpan.Length - start <= chunkSize)
        {
            _factsToProcess.Add(textSpan.Slice(start).ToString());
            break;
        }

        // Extend the chunk to the next whitespace so a word is not split at the end.
        int end = start + chunkSize;
        while (end < textSpan.Length && !char.IsWhiteSpace(textSpan[end]))
        {
            end++;
        }
        _factsToProcess.Add(textSpan.Slice(start, end - start).ToString());

        if (end == textSpan.Length)
        {
            break;
        }

        // Move forward by one step, then back up to a whitespace boundary. The back-up must never
        // reach the current start: that made the original loop re-queue the same chunk forever
        // when a run without whitespace was longer than the step.
        int next = start + stepSize;
        int aligned = next;
        while (aligned > start && !char.IsWhiteSpace(textSpan[aligned]))
        {
            aligned--;
        }
        start = aligned > start ? aligned : next;
    }
}
/// <summary>
/// A piece of knowledgebase text together with an integer identifier.
/// </summary>
internal struct Fact
{
    /// <summary>Gets or sets the text of the fact.</summary>
    public string Text { get; set; }

    /// <summary>Gets or sets the identifier of the fact.</summary>
    public int Id { get; set; }

    /// <summary>
    /// Initializes a new <see cref="Fact"/>.
    /// </summary>
    /// <param name="text">The text of the fact.</param>
    /// <param name="id">The identifier of the fact.</param>
    public Fact(string text, int id)
    {
        Id = id;
        Text = text;
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using OpenAI.Chat;
using static Azure.Provisioning.CloudMachine.OpenAI.EmbeddingKnowledgebase;

namespace Azure.Provisioning.CloudMachine.OpenAI;

/// <summary>
/// Represents a conversation with the OpenAI chat model, incorporating a knowledgebase of embeddings data.
/// </summary>
public class OpenAIConversation
{
private readonly ChatClient _client;
private readonly Prompt _prompt;
private readonly Dictionary<string, ChatTool> _tools = new();
private readonly EmbeddingKnowledgebase _knowledgebase;
private readonly ChatCompletionOptions _options = new ChatCompletionOptions();

/// <summary>
/// Initializes a new instance of the <see cref="OpenAIConversation"/> class.
/// </summary>
/// <param name="client">The ChatClient.</param>
/// <param name="tools">Any ChatTools to be used by the conversation.</param>
/// <param name="knowledgebase">The knowledgebase.</param>
/// <exception cref="ArgumentNullException"><paramref name="tools"/> is <c>null</c>.</exception>
internal OpenAIConversation(ChatClient client, IEnumerable<ChatTool> tools, EmbeddingKnowledgebase knowledgebase)
{
    if (tools is null)
    {
        throw new ArgumentNullException(nameof(tools));
    }

    _client = client;
    _knowledgebase = knowledgebase;
    _prompt = new Prompt();

    // Snapshot the sequence: it is needed both here and by the prompt, and a lazy or one-shot
    // enumerable must not be walked twice.
    var toolList = new List<ChatTool>(tools);
    foreach (ChatTool tool in toolList)
    {
        _options.Tools.Add(tool);
        _tools.Add(tool.FunctionName, tool);
    }
    _prompt.AddTools(toolList);
}

/// <summary>
/// Sends a message to the OpenAI chat model and returns the response, incorporating any relevant knowledge from the <see cref="EmbeddingKnowledgebase"/>.
/// </summary>
/// <param name="message">The user message to send.</param>
/// <returns>
/// The assistant's reply text, or a fixed explanatory message when the completion was cut short
/// by the token limit or by a content filter.
/// </returns>
public string Say(string message)
{
    // Relevant facts go into the prompt ahead of the user's own message.
    _prompt.AddFacts(_knowledgebase.FindRelevantFacts(message));
    _prompt.AddUserMessage(message);
    return CallOpenAI();
}

/// <summary>
/// Sends the accumulated prompt to the chat model once and converts the completion into the text
/// returned to the caller.
/// </summary>
/// <returns>
/// The assistant's text when the completion stopped normally; an empty string when the model asked for
/// tool calls (not implemented yet); otherwise a fixed message describing why no reply is available.
/// </returns>
/// <exception cref="NotImplementedException">The completion finished for a reason that is not handled here.</exception>
private string CallOpenAI()
{
    // NOTE(review): _options (and the tools registered on it) is not passed to CompleteChat, so the
    // model is never offered the tools - confirm this is intended until tool calls are implemented.
    ChatCompletion completion = _client.CompleteChat(_prompt.Messages).Value;
    ChatFinishReason finishReason = completion.FinishReason;

    if (finishReason == ChatFinishReason.Stop)
    {
        _prompt.AddAssistantMessage(new AssistantChatMessage(completion));
    }

    // GetSayResult also clears the messages gathered for this turn. Run it on every path so that a
    // truncated, filtered or unsupported completion does not leak this turn into the next request.
    string reply = _prompt.GetSayResult();

    switch (finishReason)
    {
        case ChatFinishReason.Stop:
            return reply;
        case ChatFinishReason.ToolCalls:
            // TODO: Implement tool calls. No assistant text was recorded, so the reply is empty.
            return reply;
        case ChatFinishReason.Length:
            return "Incomplete model output due to MaxTokens parameter or token limit exceeded.";
        case ChatFinishReason.ContentFilter:
            return "Omitted content due to a content filter flag.";
        default:
            throw new NotImplementedException("Unknown finish reason.");
    }
}

/// <summary>
/// Collects the messages that make up one request to the chat model, grouped by role.
/// </summary>
internal class Prompt
{
    internal readonly List<UserChatMessage> userChatMessages = new();
    internal readonly List<SystemChatMessage> systemChatMessages = new();
    internal readonly List<AssistantChatMessage> assistantChatMessages = new();
    internal readonly List<ToolChatMessage> toolChatMessages = new();
    internal readonly List<ChatCompletion> chatCompletions = new();
    internal readonly List<ChatTool> _tools = new();

    // Ids of the facts that are currently present in systemChatMessages.
    internal readonly List<int> _factsAlreadyInPrompt = new List<int>();

    public Prompt()
    { }

    /// <summary>
    /// Gets every message in the prompt, grouped by role in the order system, user, assistant, tool.
    /// </summary>
    public IEnumerable<ChatMessage> Messages
    {
        get
        {
            foreach (var message in systemChatMessages)
            {
                yield return message;
            }
            foreach (var message in userChatMessages)
            {
                yield return message;
            }
            foreach (var message in assistantChatMessages)
            {
                yield return message;
            }
            foreach (var message in toolChatMessages)
            {
                yield return message;
            }
        }
    }

    /// <summary>Records the tools associated with this prompt.</summary>
    /// <param name="tools">The tools to record.</param>
    public void AddTools(IEnumerable<ChatTool> tools)
    {
        _tools.AddRange(tools);
    }

    /// <summary>
    /// Adds the facts that are not yet in the prompt as a single system message, one fact per line.
    /// </summary>
    /// <param name="facts">The candidate facts.</param>
    public void AddFacts(IEnumerable<Fact> facts)
    {
        var sb = new StringBuilder();
        foreach (var fact in facts)
        {
            if (_factsAlreadyInPrompt.Contains(fact.Id))
            {
                continue;
            }
            sb.AppendLine(fact.Text);
            _factsAlreadyInPrompt.Add(fact.Id);
        }
        if (sb.Length > 0)
        {
            systemChatMessages.Add(ChatMessage.CreateSystemMessage(sb.ToString()));
        }
    }

    /// <summary>Adds a user message with the given text.</summary>
    /// <param name="message">The message text.</param>
    public void AddUserMessage(string message)
    {
        userChatMessages.Add(ChatMessage.CreateUserMessage(message));
    }

    /// <summary>Adds an assistant message with the given text.</summary>
    /// <param name="message">The message text.</param>
    public void AddAssistantMessage(string message)
    {
        assistantChatMessages.Add(ChatMessage.CreateAssistantMessage(message));
    }

    /// <summary>Adds an existing assistant message.</summary>
    /// <param name="message">The message to add.</param>
    public void AddAssistantMessage(AssistantChatMessage message)
    {
        assistantChatMessages.Add(message);
    }

    /// <summary>Adds a tool message.</summary>
    /// <param name="message">The message to add.</param>
    public void AddToolMessage(ToolChatMessage message)
    {
        toolChatMessages.Add(message);
    }

    /// <summary>
    /// Joins the first content part of every assistant message with newlines, then clears the
    /// user, assistant and system messages. Tool messages are left in place.
    /// </summary>
    /// <returns>The joined assistant text; empty when there are no assistant messages.</returns>
    internal string GetSayResult()
    {
        // An assistant message without content parts contributes an empty line instead of throwing.
        var result = string.Join("\n", assistantChatMessages.Select(m => m.Content.FirstOrDefault()?.Text));
        assistantChatMessages.Clear();
        userChatMessages.Clear();
        systemChatMessages.Clear();

        // The system messages that carried the facts are gone, so the facts must be eligible again;
        // otherwise a fact sent in one turn could never be supplied to the model in a later turn.
        _factsAlreadyInPrompt.Clear();
        return result;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.ClientModel;
using System.Linq;
using Azure.AI.OpenAI;
using Azure.Core;
using Azure.Provisioning.Authorization;
Expand Down Expand Up @@ -113,6 +114,19 @@ public static EmbeddingClient GetOpenAIEmbeddingsClient(this ClientWorkspace wor
return embeddingsClient;
}

/// <summary>
/// Creates an <see cref="EmbeddingKnowledgebase"/> that uses the workspace's OpenAI embeddings client.
/// </summary>
/// <param name="workspace">The workspace that supplies the embeddings client.</param>
/// <returns>A new <see cref="EmbeddingKnowledgebase"/>.</returns>
public static EmbeddingKnowledgebase CreateEmbeddingKnowledgebase(this ClientWorkspace workspace)
    => new EmbeddingKnowledgebase(workspace.GetOpenAIEmbeddingsClient());

/// <summary>
/// Creates an <see cref="OpenAIConversation"/> that uses the workspace's OpenAI chat client and a
/// newly created <see cref="EmbeddingKnowledgebase"/>. No tools are registered.
/// </summary>
/// <param name="workspace">The workspace that supplies the chat and embeddings clients.</param>
/// <returns>A new <see cref="OpenAIConversation"/>.</returns>
public static OpenAIConversation CreateOpenAIConversation(this ClientWorkspace workspace)
{
    EmbeddingKnowledgebase facts = workspace.CreateEmbeddingKnowledgebase();
    ChatClient chat = workspace.GetOpenAIChatClient();
    return new OpenAIConversation(chat, [], facts);
}

private static AzureOpenAIClient CreateAzureOpenAIClient(this ClientWorkspace workspace)
{
ClientConnectionOptions connection = workspace.GetConnectionOptions(typeof(AzureOpenAIClient));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public void Provisioning(string[] args)
CloudMachineWorkspace cm = new();
Console.WriteLine(cm.Id);
var embeddings = cm.GetOpenAIEmbeddingsClient();
var kb = cm.CreateEmbeddingKnowledgebase();
var conversation = cm.CreateOpenAIConversation();
}

[Ignore("no recordings yet")]
Expand Down

0 comments on commit 3643e92

Please sign in to comment.