Skip to content

Refactored the RAG process #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Things we are currently working on:
- [ ] App: Implement the process to vectorize one local file using embeddings
- [ ] Runtime: Integration of the vector database [LanceDB](https://github.com/lancedb/lancedb)
- [ ] App: Implement the continuous process of vectorizing data
- [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284), [#286](https://github.com/MindWorkAI/AI-Studio/pull/286))~~
- [x] ~~App: Define a common retrieval context interface for the integration of RAG processes in chats (PR [#281](https://github.com/MindWorkAI/AI-Studio/pull/281), [#284](https://github.com/MindWorkAI/AI-Studio/pull/284), [#286](https://github.com/MindWorkAI/AI-Studio/pull/286), [#287](https://github.com/MindWorkAI/AI-Studio/pull/287))~~
- [ ] App: Define a common augmentation interface for the integration of RAG processes in chats
- [x] ~~App: Integrate data sources in chats (PR [#282](https://github.com/MindWorkAI/AI-Studio/pull/282))~~

Expand Down
202 changes: 9 additions & 193 deletions app/MindWork AI Studio/Chat/ContentText.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
using System.Text.Json.Serialization;

using AIStudio.Agents;
using AIStudio.Components;
using AIStudio.Provider;
using AIStudio.Settings;
using AIStudio.Tools.RAG;
using AIStudio.Tools.Services;
using AIStudio.Tools.RAG.RAGProcesses;

namespace AIStudio.Chat;

Expand Down Expand Up @@ -43,204 +40,23 @@ public async Task CreateFromProviderAsync(IProvider provider, Model chatModel, I
{
if(chatThread is null)
return;

var logger = Program.SERVICE_PROVIDER.GetService<ILogger<ContentText>>()!;
var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;
var dataSourceService = Program.SERVICE_PROVIDER.GetService<DataSourceService>()!;

//
// 1. Check if the user wants to bind any data sources to the chat:
//
if (chatThread.DataSourceOptions.IsEnabled() && lastPrompt is not null)
// Call the RAG process. Right now, we only have one RAG process:
if (lastPrompt is not null)
{
logger.LogInformation("Data sources are enabled for this chat.");

// Across the different code-branches, we keep track of whether it
// makes sense to proceed with the RAG process:
var proceedWithRAG = true;

//
// When the user wants to bind data sources to the chat, we
// have to check if the data sources are available for the
// selected provider. Also, we have to check if any ERI
// data sources changed its security requirements.
//
List<IDataSource> preselectedDataSources = chatThread.DataSourceOptions.PreselectedDataSourceIds.Select(id => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == id)).Where(ds => ds is not null).ToList()!;
var dataSources = await dataSourceService.GetDataSources(provider, preselectedDataSources);
var selectedDataSources = dataSources.SelectedDataSources;

//
// Should the AI select the data sources?
//
if (chatThread.DataSourceOptions.AutomaticDataSourceSelection)
{
// Get the agent for the data source selection:
var selectionAgent = Program.SERVICE_PROVIDER.GetService<AgentDataSourceSelection>()!;

// Let the AI agent do its work:
IReadOnlyList<DataSourceAgentSelected> finalAISelection = [];
var aiSelectedDataSources = await selectionAgent.PerformSelectionAsync(provider, lastPrompt, chatThread, dataSources, token);

// Check if the AI selected any data sources:
if(aiSelectedDataSources.Count is 0)
{
logger.LogWarning("The AI did not select any data sources. The RAG process is skipped.");
proceedWithRAG = false;

// Send the selected data sources to the data source selection component.
// Then, the user can see which data sources were selected by the AI.
await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection);
chatThread.AISelectedDataSources = finalAISelection;
}
else
{
// Log the selected data sources:
var selectedDataSourceInfo = aiSelectedDataSources.Select(ds => $"[Id={ds.Id}, reason={ds.Reason}, confidence={ds.Confidence}]").Aggregate((a, b) => $"'{a}', '{b}'");
logger.LogInformation($"The AI selected the data sources automatically. {aiSelectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");

//
// Check how many data sources were hallucinated by the AI:
//
var totalAISelectedDataSources = aiSelectedDataSources.Count;

// Filter out the data sources that are not available:
aiSelectedDataSources = aiSelectedDataSources.Where(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id) is not null).ToList();

// Store the real AI-selected data sources:
finalAISelection = aiSelectedDataSources.Select(x => new DataSourceAgentSelected { DataSource = settings.ConfigurationData.DataSources.First(ds => ds.Id == x.Id), AIDecision = x, Selected = false }).ToList();

var numHallucinatedSources = totalAISelectedDataSources - aiSelectedDataSources.Count;
if(numHallucinatedSources > 0)
logger.LogWarning($"The AI hallucinated {numHallucinatedSources} data source(s). We ignore them.");

if (aiSelectedDataSources.Count > 3)
{
//
// We have more than 3 data sources. Let's filter by confidence.
// In order to do that, we must identify the lower and upper
// bounds of the confidence interval:
//
var confidenceValues = aiSelectedDataSources.Select(x => x.Confidence).ToList();
var lowerBound = confidenceValues.Min();
var upperBound = confidenceValues.Max();

//
// Next, we search for a threshold so that we have between 2 and 3
// data sources. When not possible, we take all data sources.
//
var threshold = 0.0f;

// Check the case where the confidence values are too close:
if (upperBound - lowerBound >= 0.01)
{
var previousThreshold = 0.0f;
for (var i = 0; i < 10; i++)
{
threshold = lowerBound + (upperBound - lowerBound) * i / 10;
var numMatches = aiSelectedDataSources.Count(x => x.Confidence >= threshold);
if (numMatches <= 1)
{
threshold = previousThreshold;
break;
}

if (numMatches is <= 3 and >= 2)
break;

previousThreshold = threshold;
}
}

//
// Filter the data sources by the threshold:
//
aiSelectedDataSources = aiSelectedDataSources.Where(x => x.Confidence >= threshold).ToList();
foreach (var dataSource in finalAISelection)
if(aiSelectedDataSources.Any(x => x.Id == dataSource.DataSource.Id))
dataSource.Selected = true;

logger.LogInformation($"The AI selected {aiSelectedDataSources.Count} data source(s) with a confidence of at least {threshold}.");

// Transform the final data sources to the actual data sources:
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;
}

// We have max. 3 data sources. We take all of them:
else
{
// Transform the selected data sources to the actual data sources:
selectedDataSources = aiSelectedDataSources.Select(x => settings.ConfigurationData.DataSources.FirstOrDefault(ds => ds.Id == x.Id)).Where(ds => ds is not null).ToList()!;

// Mark the data sources as selected:
foreach (var dataSource in finalAISelection)
dataSource.Selected = true;
}

// Send the selected data sources to the data source selection component.
// Then, the user can see which data sources were selected by the AI.
await MessageBus.INSTANCE.SendMessage(null, Event.RAG_AUTO_DATA_SOURCES_SELECTED, finalAISelection);
chatThread.AISelectedDataSources = finalAISelection;
}
}
else
{
//
// No, the user made the choice manually:
//
var selectedDataSourceInfo = selectedDataSources.Select(ds => ds.Name).Aggregate((a, b) => $"'{a}', '{b}'");
logger.LogInformation($"The user selected the data sources manually. {selectedDataSources.Count} data source(s) are selected: {selectedDataSourceInfo}.");
}

if(selectedDataSources.Count == 0)
{
logger.LogWarning("No data sources are selected. The RAG process is skipped.");
proceedWithRAG = false;
}

//
// Trigger the retrieval part of the (R)AG process:
//
var dataContexts = new List<IRetrievalContext>();
if (proceedWithRAG)
{
//
// We kick off the retrieval process for each data source in parallel:
//
var retrievalTasks = new List<Task<IReadOnlyList<IRetrievalContext>>>(selectedDataSources.Count);
foreach (var dataSource in selectedDataSources)
retrievalTasks.Add(dataSource.RetrieveDataAsync(lastPrompt, chatThread, token));

//
// Wait for all retrieval tasks to finish:
//
foreach (var retrievalTask in retrievalTasks)
{
try
{
dataContexts.AddRange(await retrievalTask);
}
catch (Exception e)
{
logger.LogError(e, "An error occurred during the retrieval process.");
}
}
}

//
// Perform the augmentation of the R(A)G process:
//
if (proceedWithRAG)
{

}
var rag = new AISrcSelWithRetCtxVal();
chatThread = await rag.ProcessAsync(provider, lastPrompt, chatThread, token);
}

// Store the last time we got a response. We use this later
// to determine whether we should notify the UI about the
// new content or not. Depends on the energy saving mode
// the user chose.
var last = DateTimeOffset.Now;

// Get the settings manager:
var settings = Program.SERVICE_PROVIDER.GetService<SettingsManager>()!;

// Start another thread by using a task to uncouple
// the UI thread from the AI processing:
await Task.Run(async () =>
Expand Down
10 changes: 10 additions & 0 deletions app/MindWork AI Studio/Tools/RAG/DataSelectionResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using AIStudio.Settings;

namespace AIStudio.Tools.RAG;

/// <summary>
/// The result of a data selection process: it states whether the RAG process
/// should continue and carries the data sources that were selected.
/// </summary>
/// <param name="ProceedWithRAG">Whether it makes sense to proceed with the RAG process.</param>
/// <param name="SelectedDataSources">The selected data sources.</param>
public readonly record struct DataSelectionResult(bool ProceedWithRAG, IReadOnlyList<IDataSource> SelectedDataSources);
Loading