-
Notifications
You must be signed in to change notification settings - Fork 4.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add knowledgebase and conversation APIs (#46893)
- Loading branch information
1 parent
42e72d7
commit 3643e92
Showing
7 changed files
with
351 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
152 changes: 152 additions & 0 deletions
152
...isioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/EmbeddingKnowledgebase.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using OpenAI.Embeddings; | ||
|
||
namespace Azure.Provisioning.CloudMachine.OpenAI; | ||
|
||
/// <summary>
/// Represents a knowledgebase of facts represented by embeddings that can be used to find relevant facts based on a given text.
/// </summary>
public class EmbeddingKnowledgebase
{
    // Chunk size (in characters) used by Add when splitting a fact.
    private const int DefaultChunkSize = 1000;

    // Fraction of the chunk size by which consecutive chunks overlap.
    private const double ChunkOverlapFraction = 0.15;

    private readonly EmbeddingClient _client;

    // Chunks that have been queued but not yet sent to the embedding client.
    private readonly List<string> _factsToProcess = new List<string>();

    // Parallel lists: _vectors[i] is the embedding of _facts[i].
    private readonly List<ReadOnlyMemory<float>> _vectors = new List<ReadOnlyMemory<float>>();
    private readonly List<string> _facts = new List<string>();

    internal EmbeddingKnowledgebase(EmbeddingClient client)
    {
        _client = client;
    }

    /// <summary>
    /// Add a fact to the knowledgebase.
    /// </summary>
    /// <remarks>
    /// The fact is split into overlapping chunks and every chunk is embedded immediately.
    /// </remarks>
    /// <param name="fact">The fact to add.</param>
    /// <exception cref="ArgumentNullException"><paramref name="fact"/> is null.</exception>
    public void Add(string fact)
    {
        ChunkAndAddToFactsToProcess(fact, DefaultChunkSize);
        ProcessUnprocessedFacts();
    }

    /// <summary>
    /// Finds the stored facts whose embeddings are closest to the embedding of <paramref name="text"/>.
    /// </summary>
    /// <param name="text">The text to compare the stored facts against.</param>
    /// <param name="threshold">Maximum cosine distance (1 - cosine similarity) a fact may have to be returned.</param>
    /// <param name="top">Maximum number of facts to return.</param>
    /// <returns>Up to <paramref name="top"/> facts ordered by increasing distance; empty when nothing qualifies.</returns>
    internal List<Fact> FindRelevantFacts(string text, float threshold = 0.29f, int top = 3)
    {
        // No-op when nothing is pending.
        ProcessUnprocessedFacts();

        var results = new List<Fact>();
        if (_vectors.Count == 0 || top <= 0)
        {
            // Nothing can be returned, so do not spend a service call embedding the query.
            return results;
        }

        ReadOnlySpan<float> textVector = ProcessFact(text).Span;

        var distances = new List<(float Distance, int Index)>(_vectors.Count);
        for (int index = 0; index < _vectors.Count; index++)
        {
            float distance = 1.0f - CosineSimilarity(_vectors[index].Span, textVector);
            distances.Add((distance, index));
        }
        distances.Sort((left, right) => left.Distance.CompareTo(right.Distance));

        int limit = Math.Min(top, distances.Count);
        for (int i = 0; i < limit; i++)
        {
            (float distance, int index) = distances[i];
            if (distance > threshold)
            {
                // Sorted ascending: every remaining candidate is at least as far away.
                break;
            }
            results.Add(new Fact(_facts[index], index));
        }
        return results;
    }

    /// <summary>
    /// Computes the cosine similarity of two vectors of equal length.
    /// </summary>
    /// <returns>
    /// The cosine similarity, or 0 when either vector has zero magnitude (the similarity is undefined there;
    /// returning 0 keeps NaN out of the distance ranking, where it would sort first and pass the threshold test).
    /// </returns>
    /// <exception cref="ArgumentException">The vectors have different lengths.</exception>
    private static float CosineSimilarity(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
    {
        if (x.Length != y.Length)
        {
            throw new ArgumentException("Vectors must have the same number of dimensions.", nameof(y));
        }

        float dot = 0, xSumSquared = 0, ySumSquared = 0;
        for (int i = 0; i < x.Length; i++)
        {
            dot += x[i] * y[i];
            xSumSquared += x[i] * x[i];
            ySumSquared += y[i] * y[i];
        }

        float denominator = MathF.Sqrt(xSumSquared) * MathF.Sqrt(ySumSquared);
        return denominator > 0 ? dot / denominator : 0f;
    }

    /// <summary>
    /// Embeds every pending chunk in one request and moves the chunks into the searchable lists.
    /// </summary>
    private void ProcessUnprocessedFacts()
    {
        if (_factsToProcess.Count == 0)
        {
            return;
        }

        var embeddings = _client.GenerateEmbeddings(_factsToProcess);

        foreach (var embedding in embeddings.Value)
        {
            _vectors.Add(embedding.ToFloats());
            // embedding.Index is used to pair each returned vector with the input chunk it belongs to.
            _facts.Add(_factsToProcess[embedding.Index]);
        }

        // Only cleared after a successful request, so a failed request leaves the chunks queued.
        _factsToProcess.Clear();
    }

    /// <summary>
    /// Requests the embedding vector for a single piece of text.
    /// </summary>
    private ReadOnlyMemory<float> ProcessFact(string fact)
    {
        var embedding = _client.GenerateEmbedding(fact);
        return embedding.Value.ToFloats();
    }

    /// <summary>
    /// Splits <paramref name="text"/> into overlapping chunks of roughly <paramref name="chunkSize"/> characters
    /// and queues them for embedding.
    /// </summary>
    /// <remarks>
    /// Chunk ends are extended forward to the next whitespace character, and chunk starts are moved back to the
    /// previous whitespace character, so words are not split when whitespace is available. When a run without
    /// whitespace is longer than the step size the chunk start falls inside that run instead, which guarantees
    /// that the loop always advances.
    /// </remarks>
    /// <param name="text">The text to split. Empty text queues nothing.</param>
    /// <param name="chunkSize">Target chunk length in characters; must be greater than zero.</param>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="chunkSize"/> is zero or negative.</exception>
    internal void ChunkAndAddToFactsToProcess(string text, int chunkSize)
    {
        if (text is null)
        {
            throw new ArgumentNullException(nameof(text));
        }
        if (chunkSize <= 0)
        {
            throw new ArgumentException("Chunk size must be greater than zero.", nameof(chunkSize));
        }

        int overlapSize = (int)(chunkSize * ChunkOverlapFraction);
        int stepSize = chunkSize - overlapSize; // always >= 1 because overlapSize < chunkSize
        ReadOnlySpan<char> textSpan = text.AsSpan();

        int start = 0;
        while (start < text.Length)
        {
            if (start + chunkSize >= text.Length)
            {
                // The remainder fits in one chunk; it is the last one.
                _factsToProcess.Add(textSpan.Slice(start).ToString());
                break;
            }

            int end = start + chunkSize;
            while (end < text.Length && !char.IsWhiteSpace(textSpan[end]))
            {
                end++;
            }
            _factsToProcess.Add(textSpan.Slice(start, end - start).ToString());

            if (end >= text.Length)
            {
                // Extending to a word boundary already consumed the rest of the text.
                break;
            }

            // start + stepSize < text.Length here because stepSize <= chunkSize.
            int next = start + stepSize;
            while (next > start && !char.IsWhiteSpace(textSpan[next]))
            {
                next--;
            }
            if (next == start)
            {
                // No whitespace inside the step window: advance by a full step anyway so the loop terminates.
                next = start + stepSize;
            }
            start = next;
        }
    }

    /// <summary>
    /// A fact returned from a knowledgebase search.
    /// </summary>
    internal struct Fact
    {
        public Fact(string text, int id)
        {
            Text = text;
            Id = id;
        }

        /// <summary>The text of the stored chunk.</summary>
        public string Text { get; set; }

        /// <summary>The position of the chunk in the knowledgebase's internal list.</summary>
        public int Id { get; set; }
    }
}
169 changes: 169 additions & 0 deletions
169
...provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIConversation.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
using OpenAI.Chat; | ||
using static Azure.Provisioning.CloudMachine.OpenAI.EmbeddingKnowledgebase; | ||
|
||
namespace Azure.Provisioning.CloudMachine.OpenAI; | ||
|
||
/// <summary>
/// Represents a conversation with the OpenAI chat model, incorporating a knowledgebase of embeddings data.
/// </summary>
public class OpenAIConversation
{
    private readonly ChatClient _client;
    private readonly Prompt _prompt;
    private readonly Dictionary<string, ChatTool> _tools = new();
    private readonly EmbeddingKnowledgebase _knowledgebase;
    // NOTE(review): the tools are collected here but these options are not passed to CompleteChat,
    // so the model never sees them. Wire them up together with the tool-call handling.
    private readonly ChatCompletionOptions _options = new ChatCompletionOptions();

    /// <summary>
    /// Initializes a new instance of the <see cref="OpenAIConversation"/> class.
    /// </summary>
    /// <param name="client">The ChatClient.</param>
    /// <param name="tools">Any ChatTools to be used by the conversation. Null is treated as no tools.</param>
    /// <param name="knowledgebase">The knowledgebase.</param>
    internal OpenAIConversation(ChatClient client, IEnumerable<ChatTool> tools, EmbeddingKnowledgebase knowledgebase)
    {
        // Materialize once: the sequence is consumed twice below and may be lazily evaluated.
        List<ChatTool> toolList = tools is null ? new List<ChatTool>() : new List<ChatTool>(tools);

        foreach (ChatTool tool in toolList)
        {
            _options.Tools.Add(tool);
            _tools.Add(tool.FunctionName, tool);
        }
        _client = client;
        _knowledgebase = knowledgebase;
        _prompt = new Prompt();
        _prompt.AddTools(toolList);
    }

    /// <summary>
    /// Sends a message to the OpenAI chat model and returns the response, incorporating any relevant knowledge from the <see cref="EmbeddingKnowledgebase"/>.
    /// </summary>
    /// <param name="message">The user message to send.</param>
    /// <returns>
    /// The text of the model's reply, or a fixed explanatory string when the completion was cut off
    /// by the token limit or by a content filter.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="message"/> is null.</exception>
    public string Say(string message)
    {
        if (message is null)
        {
            throw new ArgumentNullException(nameof(message));
        }

        List<Fact> facts = _knowledgebase.FindRelevantFacts(message);
        _prompt.AddFacts(facts);
        _prompt.AddUserMessage(message);
        return CallOpenAI();
    }

    /// <summary>
    /// Sends the current prompt to the chat model once and maps the finish reason to the value returned by Say.
    /// </summary>
    /// <exception cref="NotImplementedException">The completion has a finish reason this class does not handle.</exception>
    private string CallOpenAI()
    {
        try
        {
            var completion = _client.CompleteChat(_prompt.Messages).Value;
            switch (completion.FinishReason)
            {
                case ChatFinishReason.Stop:
                    _prompt.AddAssistantMessage(new AssistantChatMessage(completion));
                    return _prompt.GetSayResult();
                case ChatFinishReason.ToolCalls:
                    // TODO: Implement tool calls. Until then no assistant message is recorded,
                    // so this yields an empty string.
                    return _prompt.GetSayResult();
                case ChatFinishReason.Length:
                    return "Incomplete model output due to MaxTokens parameter or token limit exceeded.";
                case ChatFinishReason.ContentFilter:
                    return "Omitted content due to a content filter flag.";
                default:
                    throw new NotImplementedException("Unknown finish reason.");
            }
        }
        finally
        {
            // Every turn starts from an empty prompt, including after a failed or truncated one;
            // otherwise the failed turn's messages would be sent again with the next message.
            _prompt.Reset();
        }
    }

    /// <summary>
    /// Holds the messages, tools and fact bookkeeping that make up the request for the current turn.
    /// </summary>
    internal class Prompt
    {
        internal readonly List<UserChatMessage> userChatMessages = new();
        internal readonly List<SystemChatMessage> systemChatMessages = new();
        internal readonly List<AssistantChatMessage> assistantChatMessages = new();
        internal readonly List<ToolChatMessage> toolChatMessages = new();
        internal readonly List<ChatCompletion> chatCompletions = new();
        internal readonly List<ChatTool> _tools = new();
        // Ids of the facts contained in the current system messages.
        internal readonly List<int> _factsAlreadyInPrompt = new List<int>();

        public Prompt()
        { }

        /// <summary>
        /// All messages of the prompt, grouped by role in the order system, user, assistant, tool.
        /// </summary>
        public IEnumerable<ChatMessage> Messages
        {
            get
            {
                foreach (var message in systemChatMessages)
                {
                    yield return message;
                }
                foreach (var message in userChatMessages)
                {
                    yield return message;
                }
                foreach (var message in assistantChatMessages)
                {
                    yield return message;
                }
                foreach (var message in toolChatMessages)
                {
                    yield return message;
                }
            }
        }

        /// <summary>Registers the given tools with the prompt. Null is ignored.</summary>
        public void AddTools(IEnumerable<ChatTool> tools)
        {
            if (tools is null)
            {
                return;
            }
            _tools.AddRange(tools);
        }

        /// <summary>
        /// Adds the facts that are not yet part of the prompt as a single system message, one fact per line.
        /// </summary>
        public void AddFacts(IEnumerable<Fact> facts)
        {
            var sb = new StringBuilder();
            foreach (var fact in facts)
            {
                if (_factsAlreadyInPrompt.Contains(fact.Id))
                {
                    continue;
                }
                sb.AppendLine(fact.Text);
                _factsAlreadyInPrompt.Add(fact.Id);
            }
            if (sb.Length > 0)
            {
                systemChatMessages.Add(ChatMessage.CreateSystemMessage(sb.ToString()));
            }
        }

        /// <summary>Adds a user message with the given text.</summary>
        public void AddUserMessage(string message)
        {
            userChatMessages.Add(ChatMessage.CreateUserMessage(message));
        }

        /// <summary>Adds an assistant message with the given text.</summary>
        public void AddAssistantMessage(string message)
        {
            assistantChatMessages.Add(ChatMessage.CreateAssistantMessage(message));
        }

        /// <summary>Adds an existing assistant message.</summary>
        public void AddAssistantMessage(AssistantChatMessage message)
        {
            assistantChatMessages.Add(message);
        }

        /// <summary>Adds a tool result message.</summary>
        public void AddToolMessage(ToolChatMessage message)
        {
            toolChatMessages.Add(message);
        }

        /// <summary>
        /// Returns the text of the first content part of every assistant message, joined with newlines,
        /// and then resets the prompt for the next turn.
        /// </summary>
        internal string GetSayResult()
        {
            // An assistant message can have no content parts; treat that as empty text instead of throwing.
            var result = string.Join(
                "\n",
                assistantChatMessages.Select(m => m.Content.Count > 0 ? m.Content[0].Text : string.Empty));
            Reset();
            return result;
        }

        /// <summary>
        /// Removes all messages of the current turn. Registered tools are kept.
        /// </summary>
        internal void Reset()
        {
            assistantChatMessages.Clear();
            userChatMessages.Clear();
            systemChatMessages.Clear();
            // Tool results refer to assistant tool calls, which were just removed.
            toolChatMessages.Clear();
            // The system messages carrying the facts are gone, so the facts must be eligible again;
            // otherwise a fact used in one turn would never reach the model in any later turn.
            _factsAlreadyInPrompt.Clear();
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters