diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props
index ee3dfc9c6c28c..df417b1aaa291 100644
--- a/eng/Packages.Data.props
+++ b/eng/Packages.Data.props
@@ -157,6 +157,7 @@
+
diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs
index 3cb0ac044770f..f9c814a1b6905 100644
--- a/sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs
+++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs
@@ -138,9 +138,21 @@ namespace Azure.Provisioning.CloudMachine.OpenAI
{
public static partial class AzureOpenAIExtensions // extension methods that all take an Azure.Core.ClientWorkspace receiver
{
+ public static Azure.Provisioning.CloudMachine.OpenAI.EmbeddingKnowledgebase CreateEmbeddingKnowledgebase(this Azure.Core.ClientWorkspace workspace) { throw null; } // signature-only entry; the body is a placeholder
+ public static Azure.Provisioning.CloudMachine.OpenAI.OpenAIConversation CreateOpenAIConversation(this Azure.Core.ClientWorkspace workspace) { throw null; } // signature-only entry; the body is a placeholder
public static OpenAI.Chat.ChatClient GetOpenAIChatClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
public static OpenAI.Embeddings.EmbeddingClient GetOpenAIEmbeddingsClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
}
+ public partial class EmbeddingKnowledgebase // public surface listed here: Add(string) only
+ {
+ internal EmbeddingKnowledgebase() { } // internal constructor: not constructible from outside the assembly
+ public void Add(string fact) { }
+ }
+ public partial class OpenAIConversation // public surface listed here: Say(string) only
+ {
+ internal OpenAIConversation() { } // internal constructor: not constructible from outside the assembly
+ public string Say(string message) { throw null; }
+ }
public partial class OpenAIFeature : Azure.Provisioning.CloudMachine.CloudMachineFeature
{
public OpenAIFeature(AiModel chatDeployment, AiModel? embeddingsDeployment = null) { }
diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/Azure.Provisioning.CloudMachine.csproj b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/Azure.Provisioning.CloudMachine.csproj
index f0cd4eb52d2d9..a233df6e84542 100644
--- a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/Azure.Provisioning.CloudMachine.csproj
+++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/Azure.Provisioning.CloudMachine.csproj
@@ -23,6 +23,7 @@
+
diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/EmbeddingKnowledgebase.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/EmbeddingKnowledgebase.cs
new file mode 100644
index 0000000000000..64c16ae5833b3
--- /dev/null
+++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/EmbeddingKnowledgebase.cs
@@ -0,0 +1,152 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using OpenAI.Embeddings;
+
+namespace Azure.Provisioning.CloudMachine.OpenAI;
+
+/// <summary>
+/// Represents a knowledgebase of facts represented by embeddings that can be used to find relevant facts based on a given text.
+/// </summary>
+public class EmbeddingKnowledgebase
+{
+    private readonly EmbeddingClient _client;
+    private readonly List<string> _factsToProcess = new List<string>(); // chunks queued until their embeddings are generated
+
+    private readonly List<ReadOnlyMemory<float>> _vectors = new List<ReadOnlyMemory<float>>(); // _vectors[i] is the embedding of _facts[i]
+    private readonly List<string> _facts = new List<string>();
+
+    internal EmbeddingKnowledgebase(EmbeddingClient client)
+    {
+        _client = client;
+    }
+
+    /// <summary>
+    /// Add a fact to the knowledgebase.
+    /// </summary>
+    /// <param name="fact">The fact to add.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="fact"/> is null.</exception>
+    public void Add(string fact)
+    {
+        ChunkAndAddToFactsToProcess(fact ?? throw new ArgumentNullException(nameof(fact)), 1000);
+        ProcessUnprocessedFacts();
+    }
+
+    // Returns up to `top` stored facts whose cosine distance to `text` is at most `threshold`, nearest first.
+    internal List<Fact> FindRelevantFacts(string text, float threshold = 0.29f, int top = 3)
+    {
+        ProcessUnprocessedFacts(); // no-op when nothing is queued
+        var results = new List<Fact>();
+        if (_vectors.Count == 0)
+        {
+            return results; // nothing stored, so skip the embedding request for the query text
+        }
+
+        ReadOnlySpan<float> textVector = ProcessFact(text).Span;
+        var distances = new List<(float Distance, int Index)>(_vectors.Count);
+        for (int index = 0; index < _vectors.Count; index++)
+        {
+            float distance = 1.0f - CosineSimilarity(_vectors[index].Span, textVector);
+            distances.Add((distance, index));
+        }
+        distances.Sort((a, b) => a.Distance.CompareTo(b.Distance));
+
+        top = Math.Min(top, distances.Count);
+        for (int i = 0; i < top; i++)
+        {
+            (float distance, int index) = distances[i];
+            if (distance > threshold)
+            {
+                break; // sorted ascending, so every later candidate is farther away
+            }
+            results.Add(new Fact(_facts[index], index));
+        }
+        return results;
+    }
+
+    private static float CosineSimilarity(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
+    {
+        float dot = 0, xSumSquared = 0, ySumSquared = 0;
+        for (int i = 0; i < x.Length; i++)
+        {
+            dot += x[i] * y[i];
+            xSumSquared += x[i] * x[i];
+            ySumSquared += y[i] * y[i];
+        }
+        float magnitude = MathF.Sqrt(xSumSquared) * MathF.Sqrt(ySumSquared);
+        return magnitude == 0f ? 0f : dot / magnitude; // an all-zero vector would otherwise yield NaN, which sorts first and passes the threshold test
+    }
+
+    // Embeds every queued chunk with a single GenerateEmbeddings call and moves it into the searchable lists.
+    private void ProcessUnprocessedFacts()
+    {
+        if (_factsToProcess.Count == 0)
+        {
+            return;
+        }
+
+        var embeddings = _client.GenerateEmbeddings(_factsToProcess);
+        foreach (var embedding in embeddings.Value)
+        {
+            _vectors.Add(embedding.ToFloats());
+            _facts.Add(_factsToProcess[embedding.Index]); // Index maps the embedding back to its input
+        }
+        _factsToProcess.Clear();
+    }
+
+    // Generates the embedding for one piece of text without storing it.
+    private ReadOnlyMemory<float> ProcessFact(string fact)
+    {
+        var embedding = _client.GenerateEmbedding(fact);
+        return embedding.Value.ToFloats();
+    }
+
+    // Splits text into whitespace-aligned chunks of about chunkSize characters (15% overlap) and queues them for embedding.
+    internal void ChunkAndAddToFactsToProcess(string text, int chunkSize)
+    {
+        if (chunkSize <= 0)
+        {
+            throw new ArgumentException("Chunk size must be greater than zero.", nameof(chunkSize));
+        }
+
+        int overlapSize = (int)(chunkSize * 0.15);
+        int stepSize = chunkSize - overlapSize;
+        ReadOnlySpan<char> textSpan = text.AsSpan(); // AsSpan() turns a null string into an empty span, so null queues nothing
+
+        int previousStart = -1;
+        for (int i = 0; i < textSpan.Length; i += stepSize)
+        {
+            // Snap the start back to whitespace only if that still advances past the previous chunk; otherwise a whitespace-free run longer than stepSize would loop forever.
+            int start = i;
+            while (start > previousStart && !char.IsWhiteSpace(textSpan[start]))
+            {
+                start--;
+            }
+            i = start > previousStart ? start : i;
+            previousStart = i;
+            int end = i + chunkSize; // pushed forward below so the chunk never ends mid-word
+            while (end < textSpan.Length && !char.IsWhiteSpace(textSpan[end]))
+            {
+                end++;
+            }
+            if (end >= textSpan.Length)
+            {
+                _factsToProcess.Add(textSpan.Slice(i).ToString());
+                break; // tail reached; stepping on would only re-queue suffixes of this chunk
+            }
+            _factsToProcess.Add(textSpan.Slice(i, end - i).ToString());
+        }
+    }
+    internal struct Fact
+    {
+        public Fact(string text, int id)
+        {
+            Text = text;
+            Id = id;
+        }
+        public string Text { get; set; }
+        public int Id { get; set; } // index of the fact within the knowledgebase
+    }
+}
diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIConversation.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIConversation.cs
new file mode 100644
index 0000000000000..1477301031362
--- /dev/null
+++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIConversation.cs
@@ -0,0 +1,169 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using OpenAI.Chat;
+using static Azure.Provisioning.CloudMachine.OpenAI.EmbeddingKnowledgebase;
+
+namespace Azure.Provisioning.CloudMachine.OpenAI;
+
+/// <summary>
+/// Represents a conversation with the OpenAI chat model, incorporating a knowledgebase of embeddings data.
+/// </summary>
+public class OpenAIConversation
+{
+    private readonly ChatClient _client;
+    private readonly Prompt _prompt;
+    private readonly Dictionary<string, ChatTool> _tools = new();
+    private readonly EmbeddingKnowledgebase _knowledgebase;
+    private readonly ChatCompletionOptions _options = new ChatCompletionOptions();
+
+    /// <summary>
+    /// Initializes a new instance of the <see cref="OpenAIConversation"/> class.
+    /// </summary>
+    /// <param name="client">The ChatClient.</param>
+    /// <param name="tools">Any ChatTools to be used by the conversation.</param>
+    /// <param name="knowledgebase">The knowledgebase.</param>
+    internal OpenAIConversation(ChatClient client, IEnumerable<ChatTool> tools, EmbeddingKnowledgebase knowledgebase)
+    {
+        foreach (var tool in tools)
+        {
+            _options.Tools.Add(tool);
+            _tools.Add(tool.FunctionName, tool);
+        }
+        _client = client;
+        _knowledgebase = knowledgebase;
+        _prompt = new Prompt();
+        _prompt.AddTools(tools);
+    }
+
+    /// <summary>
+    /// Sends a message to the OpenAI chat model and returns the response, incorporating any relevant knowledge from the <see cref="EmbeddingKnowledgebase"/>.
+    /// </summary>
+    /// <param name="message">The user message to send.</param>
+    /// <returns>The model's reply text, or a fixed notice when the completion stopped because of length or content filtering.</returns>
+    public string Say(string message)
+    {
+        List<Fact> facts = _knowledgebase.FindRelevantFacts(message);
+        _prompt.AddFacts(facts);
+        _prompt.AddUserMessage(message);
+        var response = CallOpenAI();
+        return response;
+    }
+
+    // Runs one completion for the pending prompt and maps the finish reason to the text Say returns.
+    private string CallOpenAI()
+    {
+        try
+        {
+            // NOTE(review): _options (which carries the tools) is not passed to CompleteChat - confirm whether that is intended.
+            var completion = _client.CompleteChat(_prompt.Messages).Value;
+            switch (completion.FinishReason)
+            {
+                case ChatFinishReason.ToolCalls:
+                    // TODO: Implement tool calls (until then this falls through and the reply is empty)
+                    break;
+                case ChatFinishReason.Length:
+                    return "Incomplete model output due to MaxTokens parameter or token limit exceeded.";
+                case ChatFinishReason.ContentFilter:
+                    return "Omitted content due to a content filter flag.";
+                case ChatFinishReason.Stop:
+                    _prompt.AddAssistantMessage(new AssistantChatMessage(completion));
+                    break;
+                default:
+                    throw new NotImplementedException("Unknown finish reason.");
+            }
+            return _prompt.GetSayResult();
+        }
+        finally
+        {
+            // Every exit (early return or exception) must drop this turn's state, or it leaks into the next Say call.
+            _prompt.EndTurn();
+        }
+    }
+
+    // Holds the messages for the turn currently being built; the turn is discarded once Say has produced its result.
+    internal class Prompt
+    {
+        internal readonly List<UserChatMessage> userChatMessages = new();
+        internal readonly List<SystemChatMessage> systemChatMessages = new();
+        internal readonly List<AssistantChatMessage> assistantChatMessages = new();
+        internal readonly List<ToolChatMessage> toolChatMessages = new();
+        internal readonly List<ChatCompletion> chatCompletions = new();
+        internal readonly List<ChatTool> _tools = new();
+        internal readonly List<int> _factsAlreadyInPrompt = new List<int>();
+
+        public Prompt()
+        { }
+
+        // Messages in the order they are sent: system (facts) first, then user, assistant and tool messages.
+        public IEnumerable<ChatMessage> Messages =>
+            systemChatMessages
+                .Concat<ChatMessage>(userChatMessages)
+                .Concat(assistantChatMessages)
+                .Concat(toolChatMessages);
+
+        public void AddTools(IEnumerable<ChatTool> tools)
+        {
+            foreach (var tool in tools)
+            {
+                _tools.Add(tool);
+            }
+        }
+
+        // Adds the facts not yet present in this turn as a single system message (one fact per line).
+        public void AddFacts(IEnumerable<Fact> facts)
+        {
+            var sb = new StringBuilder();
+            foreach (var fact in facts)
+            {
+                if (!_factsAlreadyInPrompt.Contains(fact.Id))
+                {
+                    sb.AppendLine(fact.Text);
+                    _factsAlreadyInPrompt.Add(fact.Id);
+                }
+            }
+            if (sb.Length > 0)
+            {
+                systemChatMessages.Add(ChatMessage.CreateSystemMessage(sb.ToString()));
+            }
+        }
+
+        public void AddUserMessage(string message)
+        {
+            userChatMessages.Add(ChatMessage.CreateUserMessage(message));
+        }
+        public void AddAssistantMessage(string message)
+        {
+            assistantChatMessages.Add(ChatMessage.CreateAssistantMessage(message));
+        }
+        public void AddAssistantMessage(AssistantChatMessage message)
+        {
+            assistantChatMessages.Add(message);
+        }
+        public void AddToolMessage(ToolChatMessage message)
+        {
+            toolChatMessages.Add(message);
+        }
+
+        // Joins the first text part of each assistant message (empty when a message has no content), then ends the turn.
+        internal string GetSayResult()
+        {
+            var result = string.Join("\n", assistantChatMessages.Select(m => m.Content.Count > 0 ? m.Content[0].Text : string.Empty));
+            EndTurn();
+            return result;
+        }
+
+        // Drops everything accumulated for the current turn, including the record of which facts were already added,
+        // so a fact that is relevant again on a later turn is re-sent instead of being silently skipped.
+        internal void EndTurn()
+        {
+            assistantChatMessages.Clear();
+            userChatMessages.Clear();
+            systemChatMessages.Clear();
+            _factsAlreadyInPrompt.Clear();
+        }
+    }
+}
diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs
index 5ebf2c79a3a27..829dd4e2480e8 100644
--- a/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs
+++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs
@@ -3,6 +3,7 @@
using System;
using System.ClientModel;
+using System.Linq;
using Azure.AI.OpenAI;
using Azure.Core;
using Azure.Provisioning.Authorization;
@@ -113,6 +114,19 @@ public static EmbeddingClient GetOpenAIEmbeddingsClient(this ClientWorkspace wor
return embeddingsClient;
}
+    // Builds a new knowledgebase on top of the workspace's OpenAI embeddings client.
+    public static EmbeddingKnowledgebase CreateEmbeddingKnowledgebase(this ClientWorkspace workspace)
+    {
+        return new EmbeddingKnowledgebase(workspace.GetOpenAIEmbeddingsClient());
+    }
+
+    // Creates a conversation with no tools, using the workspace's chat client and a newly created knowledgebase.
+    public static OpenAIConversation CreateOpenAIConversation(this ClientWorkspace workspace)
+    {
+        ChatClient chat = workspace.GetOpenAIChatClient();
+        return new OpenAIConversation(chat, [], workspace.CreateEmbeddingKnowledgebase());
+    }
+
private static AzureOpenAIClient CreateAzureOpenAIClient(this ClientWorkspace workspace)
{
ClientConnectionOptions connection = workspace.GetConnectionOptions(typeof(AzureOpenAIClient));
diff --git a/sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs b/sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs
index 3c2077ecc4490..671fa9f2930d4 100644
--- a/sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs
+++ b/sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs
@@ -31,6 +31,8 @@ public void Provisioning(string[] args)
CloudMachineWorkspace cm = new();
Console.WriteLine(cm.Id);
var embeddings = cm.GetOpenAIEmbeddingsClient();
+ var kb = cm.CreateEmbeddingKnowledgebase();
+ var conversation = cm.CreateOpenAIConversation();
}
[Ignore("no recordings yet")]