diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props index ee3dfc9c6c28c..df417b1aaa291 100644 --- a/eng/Packages.Data.props +++ b/eng/Packages.Data.props @@ -157,6 +157,7 @@ + diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs index 3cb0ac044770f..f9c814a1b6905 100644 --- a/sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs +++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs @@ -138,9 +138,21 @@ namespace Azure.Provisioning.CloudMachine.OpenAI { public static partial class AzureOpenAIExtensions { + public static Azure.Provisioning.CloudMachine.OpenAI.EmbeddingKnowledgebase CreateEmbeddingKnowledgebase(this Azure.Core.ClientWorkspace workspace) { throw null; } + public static Azure.Provisioning.CloudMachine.OpenAI.OpenAIConversation CreateOpenAIConversation(this Azure.Core.ClientWorkspace workspace) { throw null; } public static OpenAI.Chat.ChatClient GetOpenAIChatClient(this Azure.Core.ClientWorkspace workspace) { throw null; } public static OpenAI.Embeddings.EmbeddingClient GetOpenAIEmbeddingsClient(this Azure.Core.ClientWorkspace workspace) { throw null; } } + public partial class EmbeddingKnowledgebase + { + internal EmbeddingKnowledgebase() { } + public void Add(string fact) { } + } + public partial class OpenAIConversation + { + internal OpenAIConversation() { } + public string Say(string message) { throw null; } + } public partial class OpenAIFeature : Azure.Provisioning.CloudMachine.CloudMachineFeature { public OpenAIFeature(AiModel chatDeployment, AiModel? embeddingsDeployment = null) { } diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/Azure.Provisioning.CloudMachine.csproj b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/Azure.Provisioning.CloudMachine.csproj index f0cd4eb52d2d9..a233df6e84542 100644 --- a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/Azure.Provisioning.CloudMachine.csproj +++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/Azure.Provisioning.CloudMachine.csproj @@ -23,6 +23,7 @@ + diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/EmbeddingKnowledgebase.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/EmbeddingKnowledgebase.cs new file mode 100644 index 0000000000000..64c16ae5833b3 --- /dev/null +++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/EmbeddingKnowledgebase.cs @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using OpenAI.Embeddings; + +namespace Azure.Provisioning.CloudMachine.OpenAI; + +/// +/// Represents a knowledgebase of facts represented by embeddings that can be used to find relevant facts based on a given text. +/// +public class EmbeddingKnowledgebase +{ + private EmbeddingClient _client; + private List _factsToProcess = new List(); + + private List> _vectors = new List>(); + private List _facts = new List(); + + internal EmbeddingKnowledgebase(EmbeddingClient client) + { + _client = client; + } + + /// + /// Add a fact to the knowledgebase. + /// + /// The fact to add. + public void Add(string fact) + { + ChunkAndAddToFactsToProcess(fact, 1000); + ProcessUnprocessedFacts(); + } + + internal List FindRelevantFacts(string text, float threshold = 0.29f, int top = 3) + { + if (_factsToProcess.Count > 0) + ProcessUnprocessedFacts(); + + ReadOnlySpan textVector = ProcessFact(text).Span; + + var results = new List(); + var distances = new List<(float Distance, int Index)>(); + for (int index = 0; index < _vectors.Count; index++) + { + ReadOnlyMemory dbVector = _vectors[index]; + float distance = 1.0f - CosineSimilarity(dbVector.Span, textVector); + distances.Add((distance, index)); + } + distances.Sort(((float D1, int I1) v1, (float D2, int I2) v2) => v1.D1.CompareTo(v2.D2)); + + top = Math.Min(top, distances.Count); + for (int i = 0; i < top; i++) + { + var distance = distances[i].Distance; + if (distance > threshold) + break; + var index = distances[i].Index; + results.Add(new Fact(_facts[index], index)); + } + return results; + } + + private static float CosineSimilarity(ReadOnlySpan x, ReadOnlySpan y) + { + float dot = 0, xSumSquared = 0, ySumSquared = 0; + + for (int i = 0; i < x.Length; i++) + { + dot += x[i] * y[i]; + xSumSquared += x[i] * x[i]; + ySumSquared += y[i] * y[i]; + } + return dot / (MathF.Sqrt(xSumSquared) * MathF.Sqrt(ySumSquared)); + } + + private void ProcessUnprocessedFacts() + { + if (_factsToProcess.Count == 0) + { + return; + } + var embeddings = _client.GenerateEmbeddings(_factsToProcess); + + foreach (var embedding in embeddings.Value) + { + _vectors.Add(embedding.ToFloats()); + _facts.Add(_factsToProcess[embedding.Index]); + } + + _factsToProcess.Clear(); + } + + private ReadOnlyMemory ProcessFact(string fact) + { + var embedding = _client.GenerateEmbedding(fact); + + return embedding.Value.ToFloats(); + } + + internal void ChunkAndAddToFactsToProcess(string text, int chunkSize) + { + if (chunkSize <= 0) + { + throw new ArgumentException("Chunk size must be greater than zero.", nameof(chunkSize)); + } + + int overlapSize = (int)(chunkSize * 0.15); + int stepSize = chunkSize - overlapSize; + ReadOnlySpan textSpan = text.AsSpan(); + + for (int i = 0; i < text.Length; i += stepSize) + { + while (i > 0 && !char.IsWhiteSpace(textSpan[i])) + { + i--; + } + if (i + chunkSize > text.Length) + { + _factsToProcess.Add(textSpan.Slice(i).ToString()); + } + else + { + int end = i + chunkSize; + if (end > text.Length) + { + _factsToProcess.Add(textSpan.Slice(i).ToString()); + } + else + { + while (end < text.Length && !char.IsWhiteSpace(textSpan[end])) + { + end++; + } + _factsToProcess.Add(textSpan.Slice(i, end - i).ToString()); + } + } + } + } + internal struct Fact + { + public Fact(string text, int id) + { + Text = text; + Id = id; + } + + public string Text { get; set; } + public int Id { get; set; } + } +} diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIConversation.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIConversation.cs new file mode 100644 index 0000000000000..1477301031362 --- /dev/null +++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIConversation.cs @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using OpenAI.Chat; +using static Azure.Provisioning.CloudMachine.OpenAI.EmbeddingKnowledgebase; + +namespace Azure.Provisioning.CloudMachine.OpenAI; + +/// +/// Represents a conversation with the OpenAI chat model, incorporating a knowledgebase of embeddings data. +/// +public class OpenAIConversation +{ + private readonly ChatClient _client; + private readonly Prompt _prompt; + private readonly Dictionary _tools = new(); + private readonly EmbeddingKnowledgebase _knowledgebase; + private readonly ChatCompletionOptions _options = new ChatCompletionOptions(); + + /// + /// Initializes a new instance of the class. + /// + /// The ChatClient. + /// Any ChatTools to be used by the conversation. + /// The knowledgebase. + internal OpenAIConversation(ChatClient client, IEnumerable tools, EmbeddingKnowledgebase knowledgebase) + { + foreach (var tool in tools) + { + _options.Tools.Add(tool); + _tools.Add(tool.FunctionName, tool); + } + _client = client; + _knowledgebase = knowledgebase; + _prompt = new Prompt(); + _prompt.AddTools(tools); + } + + /// + /// Sends a message to the OpenAI chat model and returns the response, incorporating any relevant knowledge from the . + /// + /// + /// + public string Say(string message) + { + List facts = _knowledgebase.FindRelevantFacts(message); + _prompt.AddFacts(facts); + _prompt.AddUserMessage(message); + var response = CallOpenAI(); + return response; + } + + private string CallOpenAI() + { + bool requiresAction; + do + { + requiresAction = false; + var completion = _client.CompleteChat(_prompt.Messages).Value; + switch (completion.FinishReason) + { + case ChatFinishReason.ToolCalls: + // TODO: Implement tool calls + requiresAction = true; + break; + case ChatFinishReason.Length: + return "Incomplete model output due to MaxTokens parameter or token limit exceeded."; + case ChatFinishReason.ContentFilter: + return "Omitted content due to a content filter flag."; + case ChatFinishReason.Stop: + _prompt.AddAssistantMessage(new AssistantChatMessage(completion)); + break; + default: + throw new NotImplementedException("Unknown finish reason."); + } + return _prompt.GetSayResult(); + } while (requiresAction); + } + + internal class Prompt + { + internal readonly List userChatMessages = new(); + internal readonly List systemChatMessages = new(); + internal readonly List assistantChatMessages = new(); + internal readonly List toolChatMessages = new(); + internal readonly List chatCompletions = new(); + internal readonly List _tools = new(); + internal readonly List _factsAlreadyInPrompt = new List(); + + public Prompt() + { } + + public IEnumerable Messages + { + get + { + foreach (var message in systemChatMessages) + { + yield return message; + } + foreach (var message in userChatMessages) + { + yield return message; + } + foreach (var message in assistantChatMessages) + { + yield return message; + } + foreach (var message in toolChatMessages) + { + yield return message; + } + } + } + + //public ChatCompletionOptions Current => _prompt; + public void AddTools(IEnumerable tools) + { + foreach (var tool in tools) + { + _tools.Add(tool); + } + } + public void AddFacts(IEnumerable facts) + { + var sb = new StringBuilder(); + foreach (var fact in facts) + { + if (_factsAlreadyInPrompt.Contains(fact.Id)) + continue; + sb.AppendLine(fact.Text); + _factsAlreadyInPrompt.Add(fact.Id); + } + if (sb.Length > 0) + { + systemChatMessages.Add(ChatMessage.CreateSystemMessage(sb.ToString())); + } + } + public void AddUserMessage(string message) + { + userChatMessages.Add(ChatMessage.CreateUserMessage(message)); + } + public void AddAssistantMessage(string message) + { + assistantChatMessages.Add(ChatMessage.CreateAssistantMessage(message)); + } + public void AddAssistantMessage(AssistantChatMessage message) + { + assistantChatMessages.Add(message); + } + public void AddToolMessage(ToolChatMessage message) + { + toolChatMessages.Add(message); + } + + internal string GetSayResult() + { + var result = string.Join("\n", assistantChatMessages.Select(m => m.Content[0].Text)); + assistantChatMessages.Clear(); + userChatMessages.Clear(); + systemChatMessages.Clear(); + return result; + } + } +} diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs index 5ebf2c79a3a27..829dd4e2480e8 100644 --- a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs +++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs @@ -3,6 +3,7 @@ using System; using System.ClientModel; +using System.Linq; using Azure.AI.OpenAI; using Azure.Core; using Azure.Provisioning.Authorization; @@ -113,6 +114,19 @@ public static EmbeddingClient GetOpenAIEmbeddingsClient(this ClientWorkspace wor return embeddingsClient; } + public static EmbeddingKnowledgebase CreateEmbeddingKnowledgebase(this ClientWorkspace workspace) + { + EmbeddingClient embeddingsClient = workspace.GetOpenAIEmbeddingsClient(); + return new EmbeddingKnowledgebase(embeddingsClient); + } + + public static OpenAIConversation CreateOpenAIConversation(this ClientWorkspace workspace) + { + ChatClient chatClient = workspace.GetOpenAIChatClient(); + EmbeddingKnowledgebase knowledgebase = workspace.CreateEmbeddingKnowledgebase(); + return new OpenAIConversation(chatClient, [], knowledgebase); + } + private static AzureOpenAIClient CreateAzureOpenAIClient(this ClientWorkspace workspace) { ClientConnectionOptions connection = workspace.GetConnectionOptions(typeof(AzureOpenAIClient)); diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs index 3c2077ecc4490..671fa9f2930d4 100644 --- a/sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs +++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs @@ -31,6 +31,8 @@ public void Provisioning(string[] args) CloudMachineWorkspace cm = new(); Console.WriteLine(cm.Id); var embeddings = cm.GetOpenAIEmbeddingsClient(); + var kb = cm.CreateEmbeddingKnowledgebase(); + var conversation = cm.CreateOpenAIConversation(); } [Ignore("no recordings yet")]