Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/HealthData/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ void GiveAgentInstructions()
// Supply current date and time, and how to use it
instructions += PromptLibrary.Now();
instructions += "Help me enter my health data step by step.\n" +
"Ask specific questions to gather required OR optional fields I have not already provided" +
"Ask specific questions to gather required and optional fields I have not already provided" +
"Stop asking if I don't know the answer\n" +
"Automatically fix my spelling mistakes\n" +
"My health data may be complex: always record and return ALL of it.\n" +
"Always return a response. If you don't understand what I say, ask a question.";
"Always return a response:\n" +
"- If you don't understand what I say, ask a question.\n" +
"- At least respond with an OK message.";
}

public override Task ProcessCommandAsync(string cmd, IList<string> args)
Expand Down
89 changes: 65 additions & 24 deletions src/typechat.dialog/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ public class Agent<T> : IAgent
JsonTranslator<T> _translator;
IContextProvider? _contextProvider;
Prompt _instructions;
PromptBuilder _builder;
int _maxPromptLength;

/// <summary>
Expand All @@ -38,8 +37,9 @@ public Agent(JsonTranslator<T> translator, IContextProvider? contextProvider = n
ArgumentVerify.ThrowIfNull(translator, nameof(translator));
_translator = translator;
_instructions = new Prompt();
_builder = new PromptBuilder(translator.Model.ModelInfo.MaxCharCount / 2);
_maxPromptLength = _builder.MaxLength;
// By default, only use 1/2 the estimated # of characters the model supports for the
// prompts the Agent sends
_maxPromptLength = translator.Model.ModelInfo.MaxCharCount / 2;
_contextProvider = contextProvider;
}

Expand Down Expand Up @@ -73,12 +73,27 @@ public async Task<Message> GetResponseMessageAsync(Message requestMessage, Cance
ArgumentVerify.ThrowIfNull(requestMessage, nameof(requestMessage));

string requestText = requestMessage.GetText();
Prompt context = BuildContext(requestText);
string preparedRequestText = requestText;
//
// Prepare the actual message to send to the model
//
Message preparedRequestMessage = await PrepareRequestAsync(requestMessage, cancelToken);
if (!object.ReferenceEquals(preparedRequestMessage, requestMessage))
{
preparedRequestText = preparedRequestMessage.GetText();
}
//
// Prepare the context to send. For context building, use the original request text
//
Prompt context = await BuildContextAsync(requestText, preparedRequestText.Length, cancelToken);
//
// Translate
//
T response = await _translator.TranslateAsync(preparedRequestText, context, null, cancelToken).ConfigureAwait(false);

T response = await _translator.TranslateAsync(requestText, context, null, cancelToken).ConfigureAwait(false);
Message responseMessage = Message.FromAssistant(response);

OnReceivedResponse(requestMessage, responseMessage);
await ReceivedResponseAsync(requestMessage, preparedRequestMessage, responseMessage);

return responseMessage;
}
Expand All @@ -95,35 +110,61 @@ public async Task<T> GetResponseAsync(Message request, CancellationToken cancelT
return response.GetBody<T>();
}

Prompt BuildContext(string requestText)
async Task<Prompt> BuildContextAsync(string requestText, int actualRequestLength, CancellationToken cancelToken)
{
int requestLength = requestText.Length;
//
// Since are single threaded, we can keep reusing the same builder
//
_builder.Clear();
_builder.MaxLength = (_maxPromptLength - requestLength);
PromptBuilder builder = CreateBuilder(_maxPromptLength - actualRequestLength);
// Add any preamble
_builder.AddRange(_instructions);
builder.AddRange(_instructions);
//
// If a context provider is available, inject additional context
//
if (_contextProvider != null)
{
IEnumerable<IPromptSection> context = _contextProvider.GetContext(requestText);
if (context != null)
{
_builder.Add(PromptSection.Instruction("IMPORTANT CONTEXT for the user request:"));
AppendContext(_builder, context);
}
var context = _contextProvider.GetContextAsync(requestText, cancelToken);
builder.Add(PromptSection.Instruction("IMPORTANT CONTEXT for the user request:"));
await AppendContextAsync(builder, context);
}
return _builder.Prompt;
return builder.Prompt;
}

internal virtual void AppendContext(PromptBuilder builder, IEnumerable<IPromptSection> context)
/// <summary>
/// Override to customize the actual message sent to the model. Several scenarios involve
/// transforming the user's message in various ways first
/// By default, the request message is sent to the model as is
/// </summary>
/// <param name="request">request message</param>
/// <param name="cancelToken">optional cancel token</param>
/// <returns>Actual text to send to the AI</returns>
protected virtual Task<Message> PrepareRequestAsync(Message request, CancellationToken cancelToken)
{
return Task.FromResult(request);
}
/// <summary>
/// Override to customize how user context is added to the given prompt builder
/// Since the builder will limit the # of characters, you may want to do some pre-processing
/// </summary>
/// <param name="builder">builder to append to</param>
/// <param name="context">context to append</param>
/// <returns></returns>
protected virtual Task<bool> AppendContextAsync(PromptBuilder builder, IAsyncEnumerable<IPromptSection> context)
{
builder.AddRange(context);
return builder.AddRangeAsync(context);
}
/// <summary>
/// Invoked when a valid response was received - the response is placed in the message body
/// </summary>
/// <param name="request">request message</param>
/// <param name="preparedRequest">the prepared request that was actually used in translation</param>
/// <param name="response">response message</param>
/// <returns></returns>
protected virtual Task ReceivedResponseAsync(Message request, Message preparedRequest, Message response)
{
return Task.CompletedTask;
}

protected virtual void OnReceivedResponse(Message request, Message response) { }
PromptBuilder CreateBuilder(int maxLength)
{
// Future: Pool these
return new PromptBuilder(maxLength);
}
}
19 changes: 14 additions & 5 deletions src/typechat.dialog/AgentWithHistory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,29 @@ public AgentWithHistory(JsonTranslator<T> translator, IMessageStream history)
/// </summary>
public Func<T, Message?> CreateMessageForHistory { get; set; }

internal override void AppendContext(PromptBuilder builder, IEnumerable<IPromptSection> context)
protected override Task<bool> AppendContextAsync(PromptBuilder builder, IAsyncEnumerable<IPromptSection> context)
{
builder.AddHistory(context);
return builder.AddHistoryAsync(context);
}

protected override void OnReceivedResponse(Message request, Message response)
/// <summary>
/// This is where we append messages to history
/// - User message is saved as is
/// - The response is further (optionally) transformed and then saved
/// </summary>
/// <param name="request">user request</param>
/// <param name="preparedRequest"></param>
/// <param name="response"></param>
/// <returns></returns>
protected async override Task ReceivedResponseAsync(Message request, Message preparedRequest, Message response)
{
_history.Append(request);
await _history.AppendAsync(request);
Message? historyMessage = (CreateMessageForHistory != null) ?
CreateMessageForHistory(response.GetBody<T>()) :
response;
if (historyMessage != null)
{
_history.Append(historyMessage);
await _history.AppendAsync(historyMessage);
}
}
}
21 changes: 20 additions & 1 deletion src/typechat.dialog/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public static string Stringify(this object obj)
/// Will keep adding messages until the prompt runs out of room
/// </summary>
/// <param name="builder">builder used to build prompt</param>
/// <param name="history">message history to add</param>
/// <param name="context">message history to add</param>
/// <returns></returns>
public static bool AddHistory(this PromptBuilder builder, IEnumerable<IPromptSection> context)
{
Expand All @@ -35,4 +35,23 @@ public static bool AddHistory(this PromptBuilder builder, IEnumerable<IPromptSec
}
return retVal;
}

/// <summary>
/// Add messages in priority order to the prompt
/// Will keep adding messages until the prompt runs out of room
/// </summary>
/// <param name="builder"></param>
/// <param name="context"></param>
/// <returns></returns>
public static async Task<bool> AddHistoryAsync(this PromptBuilder builder, IAsyncEnumerable<IPromptSection> context)
{
int contextStartAt = builder.Prompt.Count;
bool retVal = await builder.AddRangeAsync(context);
int contextEndAt = builder.Prompt.Count;
if (contextStartAt < contextEndAt)
{
builder.Prompt.Reverse(contextStartAt, contextEndAt - contextStartAt);
}
return retVal;
}
}
8 changes: 7 additions & 1 deletion src/typechat.dialog/IContextProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,11 @@ namespace Microsoft.TypeChat.Dialog;
/// </summary>
public interface IContextProvider
{
IEnumerable<IPromptSection>? GetContext(string request);
/// <summary>
/// Return relevant context for this request
/// </summary>
/// <param name="request">user request</param>
/// <param name="cancelToken">optional cancel token</param>
/// <returns></returns>
IAsyncEnumerable<IPromptSection> GetContextAsync(string request, CancellationToken cancelToken);
}
19 changes: 17 additions & 2 deletions src/typechat.dialog/IMessageStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,31 @@ public interface IMessageStream : IContextProvider
/// <param name="message">message to append</param>
void Append(Message message);
/// <summary>
/// Append a message to the stream
/// </summary>
/// <param name="message">message to append</param>
Task AppendAsync(Message message);
/// <summary>
/// Return all messages in the stream
/// </summary>
/// <returns></returns>
/// <returns>an enumeration of messages</returns>
IEnumerable<Message> All();
/// <summary>
/// Return all messages in the stream
/// </summary>
/// <returns>An async enumeration of messages</returns>
IAsyncEnumerable<Message> AllAsync(CancellationToken cancelToken);
/// <summary>
/// Return the newest messages in the stream in order - most recent messages first
/// </summary>
/// <returns></returns>
/// <returns>an enumeration of messages</returns>
IEnumerable<Message> Newest();
/// <summary>
/// Return the newest messages in the stream in order - most recent messages first
/// </summary>
/// <returns>An async enumeration of messages</returns>
IAsyncEnumerable<Message> NewestAsync(CancellationToken cancelToken);
/// <summary>
/// Clear the stream
/// </summary>
void Clear();
Expand Down
1 change: 1 addition & 0 deletions src/typechat.dialog/Includes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
global using System.Text.Json;
global using System.Text.Json.Serialization;
global using System.Threading.Tasks;
global using System.Runtime.CompilerServices;
global using Microsoft.TypeChat;
1 change: 1 addition & 0 deletions src/typechat.dialog/Message.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public Message(string from, object body)

/// <summary>
/// Return the message body serialized as text
/// You can override this to cache generated text
/// </summary>
/// <returns>body as text</returns>
public virtual string GetText()
Expand Down
56 changes: 55 additions & 1 deletion src/typechat.dialog/MessageList.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ public MessageList(int capacity = 4)
/// <param name="message">message to append</param>
public void Append(Message message) => Add(message);

/// <summary>
/// Append a message to the message stream
/// </summary>
/// <param name="message">message to append</param>
public Task AppendAsync(Message message)
{
Add(message);
return Task.CompletedTask;
}

/// <summary>
/// Return an enumeration of messages, most recent first
/// </summary>
Expand All @@ -60,7 +70,7 @@ public IEnumerable<Message> Newest()
}

/// <summary>
/// Just returns the newest messages. You can build other message lists that support semantic and
/// Just returns messages in order of newest first. You can build other message lists that support semantic and
/// other forms of similarity
/// supply
/// </summary>
Expand All @@ -71,6 +81,50 @@ public IEnumerable<IPromptSection> GetContext(string request)
return Newest();
}

#pragma warning disable 1998

/// <summary>
/// Enumerate all messages asynchronously, oldest first
/// </summary>
/// <param name="cancelToken"></param>
/// <returns>async enumeration</returns>
public async IAsyncEnumerable<Message> AllAsync([EnumeratorCancellation] CancellationToken cancelToken = default)
{
for (int i = 0; i < Count; ++i)
{
yield return this[i];
}
}

/// <summary>
/// Enumerate messages asynchronously - newest first
/// </summary>
/// <param name="cancelToken"></param>
/// <returns>async enumeration</returns>
public async IAsyncEnumerable<Message> NewestAsync([EnumeratorCancellation] CancellationToken cancelToken = default)
{
for (int i = Count - 1; i >= 0; --i)
{
yield return this[i];
}
}

/// <summary>
/// Just returns messages in order of newest first. You can build other message lists that
/// support semantic and other forms of similarity
/// </summary>
/// <param name="request">find messages nearest to this</param>
/// <param name="cancelToken">optional cancel token</param>
/// <returns></returns>
public async IAsyncEnumerable<IPromptSection>? GetContextAsync(string request, [EnumeratorCancellation] CancellationToken cancelToken)
{
for (int i = Count - 1; i >= 0; --i)
{
yield return this[i];
}
}

#pragma warning restore 1998

/// <summary>
/// Close the message stream. MessageList does nothing here
/// </summary>
Expand Down
Loading