Skip to content

Add retrieval integration, part I #281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Things we are currently working on:
- [ ] App: Implement the process to vectorize one local file using embeddings
- [ ] Runtime: Integration of the vector database [LanceDB](https://github.com/lancedb/lancedb)
- [ ] App: Implement the continuous process of vectorizing data
- [ ] App: Define a common retrieval context interface for the integration of RAG processes in chats
- [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281))~~
- [ ] App: Define a common augmentation interface for the integration of RAG processes in chats
- [ ] App: Integrate data sources in chats

Expand Down
14 changes: 9 additions & 5 deletions app/MindWork AI Studio/Agents/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ public abstract class AgentBase(ILogger<AgentBase> logger, SettingsManager setti
protected ThreadSafeRandom RNG { get; init; } = rng;

protected ILogger<AgentBase> Logger { get; init; } = logger;

protected IContent? lastUserPrompt;

/// <summary>
/// Represents the type or category of this agent.
Expand Down Expand Up @@ -63,15 +65,17 @@ public abstract class AgentBase(ILogger<AgentBase> logger, SettingsManager setti
protected DateTimeOffset AddUserRequest(ChatThread thread, string request)
{
var time = DateTimeOffset.Now;
this.lastUserPrompt = new ContentText
{
Text = request,
};

thread.Blocks.Add(new ContentBlock
{
Time = time,
ContentType = ContentType.TEXT,
Role = ChatRole.USER,
Content = new ContentText
{
Text = request,
},
Content = this.lastUserPrompt,
});

return time;
Expand Down Expand Up @@ -103,6 +107,6 @@ protected async Task AddAIResponseAsync(ChatThread thread, DateTimeOffset time)
// Use the selected provider to get the AI response.
// By awaiting this line, we wait for the entire
// content to be streamed.
await aiText.CreateFromProviderAsync(providerSettings.CreateProvider(this.Logger), this.SettingsManager, providerSettings.Model, thread);
await aiText.CreateFromProviderAsync(providerSettings.CreateProvider(this.Logger), this.SettingsManager, providerSettings.Model, this.lastUserPrompt, thread);
}
}
13 changes: 8 additions & 5 deletions app/MindWork AI Studio/Assistants/AssistantBase.razor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ public abstract partial class AssistantBase : ComponentBase, IMessageBusReceiver
protected bool inputIsValid;
protected Profile currentProfile = Profile.NO_PROFILE;
protected ChatThread? chatThread;
protected IContent? lastUserPrompt;

private readonly Timer formChangeTimer = new(TimeSpan.FromSeconds(1.6));

Expand Down Expand Up @@ -242,16 +243,18 @@ protected Guid CreateChatThread(Guid workspaceId, string name)
protected DateTimeOffset AddUserRequest(string request, bool hideContentFromUser = false)
{
var time = DateTimeOffset.Now;
this.lastUserPrompt = new ContentText
{
Text = request,
};

this.chatThread!.Blocks.Add(new ContentBlock
{
Time = time,
ContentType = ContentType.TEXT,
HideFromUser = hideContentFromUser,
Role = ChatRole.USER,
Content = new ContentText
{
Text = request,
},
Content = this.lastUserPrompt,
});

return time;
Expand Down Expand Up @@ -287,7 +290,7 @@ protected async Task<string> AddAIResponseAsync(DateTimeOffset time, bool hideCo
// Use the selected provider to get the AI response.
// By awaiting this line, we wait for the entire
// content to be streamed.
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.SettingsManager, this.providerSettings.Model, this.chatThread);
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.SettingsManager, this.providerSettings.Model, this.lastUserPrompt, this.chatThread);

this.isProcessing = false;
this.StateHasChanged();
Expand Down
4 changes: 2 additions & 2 deletions app/MindWork AI Studio/Chat/ContentBlockComponent.razor
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@
break;

case ContentType.IMAGE:
if (this.Content is ContentImage imageContent)
if (this.Content is ContentImage { SourceType: ContentImageSource.URL or ContentImageSource.LOCAL_PATH } imageContent)
{
<MudImage Src="@imageContent.URL"/>
<MudImage Src="@imageContent.Source"/>
}

break;
Expand Down
13 changes: 8 additions & 5 deletions app/MindWork AI Studio/Chat/ContentImage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,23 @@ public sealed class ContentImage : IContent
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;

/// <inheritdoc />
public Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, ChatThread chatChatThread, CancellationToken token = default)
public Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default)
{
throw new NotImplementedException();
}

#endregion

/// <summary>
/// The URL of the image.
/// The type of the image source.
/// </summary>
public string URL { get; set; } = string.Empty;
/// <remarks>
/// Is the image source a URL, a local file path, a base64 string, etc.?
/// </remarks>
public required ContentImageSource SourceType { get; init; }

/// <summary>
/// The local path of the image.
/// The image source.
/// </summary>
public string LocalPath { get; set; } = string.Empty;
public required string Source { get; set; }
}
8 changes: 8 additions & 0 deletions app/MindWork AI Studio/Chat/ContentImageSource.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace AIStudio.Chat;

public enum ContentImageSource
{
URL,
LOCAL_PATH,
BASE64,
}
14 changes: 13 additions & 1 deletion app/MindWork AI Studio/Chat/ContentText.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,23 @@ public sealed class ContentText : IContent
public Func<Task> StreamingEvent { get; set; } = () => Task.CompletedTask;

/// <inheritdoc />
public async Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, ChatThread? chatThread, CancellationToken token = default)
public async Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, IContent? lastPrompt, ChatThread? chatThread, CancellationToken token = default)
{
if(chatThread is null)
return;

//
// Check if the user wants to bind any data sources to the chat:
//

//
// Trigger the retrieval part of the (R)AG process:
//

//
// Perform the augmentation of the R(A)G process:
//

// Store the last time we got a response. We use this later
// to determine whether we should notify the UI about the
// new content or not. Depends on the energy saving mode
Expand Down
2 changes: 1 addition & 1 deletion app/MindWork AI Studio/Chat/IContent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ public interface IContent
/// <summary>
/// Uses the provider to create the content.
/// </summary>
public Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, ChatThread chatChatThread, CancellationToken token = default);
public Task CreateFromProviderAsync(IProvider provider, SettingsManager settings, Model chatModel, IContent? lastPrompt, ChatThread? chatChatThread, CancellationToken token = default);
}
15 changes: 10 additions & 5 deletions app/MindWork AI Studio/Components/ChatComponent.razor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,14 @@ private async Task SendMessage(bool reuseLastUserPrompt = false)
}

var time = DateTimeOffset.Now;
IContent? lastUserPrompt;
if (!reuseLastUserPrompt)
{
lastUserPrompt = new ContentText
{
Text = this.userInput,
};

//
// Add the user message to the thread:
//
Expand All @@ -305,10 +311,7 @@ private async Task SendMessage(bool reuseLastUserPrompt = false)
Time = time,
ContentType = ContentType.TEXT,
Role = ChatRole.USER,
Content = new ContentText
{
Text = this.userInput,
},
Content = lastUserPrompt,
});

// Save the chat:
Expand All @@ -319,6 +322,8 @@ private async Task SendMessage(bool reuseLastUserPrompt = false)
this.StateHasChanged();
}
}
else
lastUserPrompt = this.ChatThread.Blocks.Last(x => x.Role is ChatRole.USER).Content;

//
// Add the AI response to the thread:
Expand Down Expand Up @@ -360,7 +365,7 @@ private async Task SendMessage(bool reuseLastUserPrompt = false)
// Use the selected provider to get the AI response.
// By awaiting this line, we wait for the entire
// content to be streamed.
await aiText.CreateFromProviderAsync(this.Provider.CreateProvider(this.Logger), this.SettingsManager, this.Provider.Model, this.ChatThread, this.cancellationTokenSource.Token);
await aiText.CreateFromProviderAsync(this.Provider.CreateProvider(this.Logger), this.SettingsManager, this.Provider.Model, lastUserPrompt, this.ChatThread, this.cancellationTokenSource.Token);
}

this.cancellationTokenSource = null;
Expand Down
14 changes: 8 additions & 6 deletions app/MindWork AI Studio/Pages/Writer.razor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,19 @@ You are an assistant who helps with writing documents. You receive a sample
};

var time = DateTimeOffset.Now;
var lastUserPrompt = new ContentText
{
// We use the maximum 160 characters from the end of the text:
Text = this.userInput.Length > 160 ? this.userInput[^160..] : this.userInput,
};

this.chatThread.Blocks.Clear();
this.chatThread.Blocks.Add(new ContentBlock
{
Time = time,
ContentType = ContentType.TEXT,
Role = ChatRole.USER,
Content = new ContentText
{
// We use the maximum 160 characters from the end of the text:
Text = this.userInput.Length > 160 ? this.userInput[^160..] : this.userInput,
},
Content = lastUserPrompt,
});

var aiText = new ContentText
Expand All @@ -137,7 +139,7 @@ You are an assistant who helps with writing documents. You receive a sample
this.isStreaming = true;
this.StateHasChanged();

await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.SettingsManager, this.providerSettings.Model, this.chatThread);
await aiText.CreateFromProviderAsync(this.providerSettings.CreateProvider(this.Logger), this.SettingsManager, this.providerSettings.Model, lastUserPrompt, this.chatThread);
this.suggestion = aiText.Text;

this.isStreaming = false;
Expand Down
45 changes: 45 additions & 0 deletions app/MindWork AI Studio/Tools/RAG/IRetrievalContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
namespace AIStudio.Tools.RAG;

/// <summary>
/// The common interface for any retrieval context.
/// </summary>
public interface IRetrievalContext
{
/// <summary>
/// The name of the data source.
/// </summary>
/// <remarks>
/// Depending on the configuration, the AI is selecting the appropriate data source.
/// In order to inform the user about where the information is coming from, the data
/// source name is necessary.
/// </remarks>
public string DataSourceName { get; init; }

/// <summary>
/// The category of the content, like e.g., text, audio, image, etc.
/// </summary>
public RetrievalContentCategory Category { get; init; }

/// <summary>
/// What type of content is being retrieved? Like e.g., a project proposal, spreadsheet, art, etc.
/// </summary>
public RetrievalContentType Type { get; init; }

/// <summary>
/// The path to the content, e.g., a URL, a file path, a path in a graph database, etc.
/// </summary>
public string Path { get; init; }

/// <summary>
/// Links to related content, e.g., links to Wikipedia articles, links to sources, etc.
/// </summary>
/// <remarks>
/// Why would you need links for retrieval? You are right that not all retrieval
/// contexts need links. But think about a web search feature, where we want to
/// query a search engine and get back a list of links to the most relevant
/// matches. Think about a continuous web crawler that is constantly looking for
/// new information and adding it to the knowledge base. In these cases, links
/// are essential.
/// </remarks>
public IReadOnlyList<string> Links { get; init; }
}
12 changes: 12 additions & 0 deletions app/MindWork AI Studio/Tools/RAG/RetrievalContentCategory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace AIStudio.Tools.RAG;

public enum RetrievalContentCategory
{
NONE,
UNKNOWN,

TEXT,
IMAGE,
VIDEO,
AUDIO,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using AIStudio.Tools.ERIClient.DataModel;

namespace AIStudio.Tools.RAG;

public static class RetrievalContentCategoryExtensions
{
/// <summary>
/// Converts an ERI content type to a common retrieval content category.
/// </summary>
/// <param name="contentType">The content type yielded by the ERI server.</param>
/// <returns>The corresponding retrieval content category.</returns>
public static RetrievalContentCategory ToRetrievalContentCategory(ContentType contentType) => contentType switch
{
ContentType.NONE => RetrievalContentCategory.NONE,
ContentType.UNKNOWN => RetrievalContentCategory.UNKNOWN,
ContentType.TEXT => RetrievalContentCategory.TEXT,
ContentType.IMAGE => RetrievalContentCategory.IMAGE,
ContentType.VIDEO => RetrievalContentCategory.VIDEO,
ContentType.AUDIO => RetrievalContentCategory.AUDIO,
ContentType.SPEECH => RetrievalContentCategory.AUDIO,

_ => RetrievalContentCategory.UNKNOWN,
};
}
Loading