Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ private async IAsyncEnumerable<IngestionResult> ProcessAsync(IEnumerable<FileInf
processFileActivity?.SetTag(ProcessSource.DocumentIdTagName, document.Identifier);
_logger?.ReadDocument(document.Identifier);

await IngestAsync(document, processFileActivity, cancellationToken).ConfigureAwait(false);
document = await IngestAsync(document, processFileActivity, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand All @@ -164,12 +164,13 @@ private async IAsyncEnumerable<IngestionResult> ProcessAsync(IEnumerable<FileInf
failure = ex;
}

yield return new IngestionResult(fileInfo, document, failure);
string documentId = document?.Identifier ?? fileInfo.FullName;
yield return new IngestionResult(documentId, document, failure);
}
}
}

private async Task IngestAsync(IngestionDocument document, Activity? parentActivity, CancellationToken cancellationToken)
private async Task<IngestionDocument> IngestAsync(IngestionDocument document, Activity? parentActivity, CancellationToken cancellationToken)
{
foreach (IngestionDocumentProcessor processor in DocumentProcessors)
{
Expand All @@ -188,5 +189,7 @@ private async Task IngestAsync(IngestionDocument document, Activity? parentActiv
_logger?.WritingChunks(GetShortName(_writer));
await _writer.WriteAsync(chunks, cancellationToken).ConfigureAwait(false);
_logger?.WroteChunks(document.Identifier);

return document;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.IO;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.DataIngestion;
Expand All @@ -13,9 +12,9 @@ namespace Microsoft.Extensions.DataIngestion;
public sealed class IngestionResult
{
/// <summary>
/// Gets the source file that was ingested.
/// Gets the ID of the document that was ingested.
/// </summary>
public FileInfo Source { get; }
public string DocumentId { get; }

/// <summary>
/// Gets the ingestion document created from the source file, if reading the document has succeeded.
Expand All @@ -32,9 +31,9 @@ public sealed class IngestionResult
/// </summary>
public bool Succeeded => Exception is null;

internal IngestionResult(FileInfo source, IngestionDocument? document, Exception? exception)
internal IngestionResult(string documentId, IngestionDocument? document, Exception? exception)
{
Source = Throw.IfNull(source);
DocumentId = Throw.IfNullOrEmpty(documentId);
Document = document;
Exception = exception;
}
Expand Down
6 changes: 6 additions & 0 deletions src/Libraries/Microsoft.Extensions.DataIngestion/Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,11 @@ internal static partial class Log

[LoggerMessage(6, LogLevel.Error, "An error occurred while ingesting document '{identifier}'.")]
internal static partial void IngestingFailed(this ILogger logger, Exception exception, string identifier);

[LoggerMessage(7, LogLevel.Error, "The AI chat service returned {resultCount} instead of {expectedCount} results.")]
internal static partial void UnexpectedResultsCount(this ILogger logger, int resultCount, int expectedCount);

[LoggerMessage(8, LogLevel.Error, "Unexpected enricher failure.")]
internal static partial void UnexpectedEnricherFailure(this ILogger logger, Exception exception);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

<ItemGroup>
<ProjectReference Include="..\Microsoft.Extensions.DataIngestion.Abstractions\Microsoft.Extensions.DataIngestion.Abstractions.csproj" />
<ProjectReference Include="..\Microsoft.Extensions.AI\Microsoft.Extensions.AI.csproj" />
</ItemGroup>

<ItemGroup>
Expand All @@ -25,7 +26,6 @@

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
<PackageReference Include="System.Diagnostics.DiagnosticSource" />
<PackageReference Include="System.Collections.Immutable" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETFramework'">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Frozen;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.DataIngestion;
Expand All @@ -21,30 +19,28 @@ namespace Microsoft.Extensions.DataIngestion;
/// an optional fallback class for cases where no suitable classification can be determined.</remarks>
public sealed class ClassificationEnricher : IngestionChunkProcessor<string>
{
private readonly IChatClient _chatClient;
private readonly ChatOptions? _chatOptions;
private readonly FrozenSet<string> _predefinedClasses;
private readonly EnricherOptions _options;
private readonly ChatMessage _systemPrompt;
private readonly ILogger? _logger;

/// <summary>
/// Initializes a new instance of the <see cref="ClassificationEnricher"/> class.
/// </summary>
/// <param name="chatClient">The chat client used for classification.</param>
/// <param name="options">The options for the classification enricher.</param>
/// <param name="predefinedClasses">The set of predefined classification classes.</param>
/// <param name="chatOptions">Options for the chat client.</param>
/// <param name="fallbackClass">The fallback class to use when no suitable classification is found. When not provided, it defaults to "Unknown".</param>
public ClassificationEnricher(IChatClient chatClient, ReadOnlySpan<string> predefinedClasses,
ChatOptions? chatOptions = null, string? fallbackClass = null)
public ClassificationEnricher(EnricherOptions options, ReadOnlySpan<string> predefinedClasses,
string? fallbackClass = null)
{
_chatClient = Throw.IfNull(chatClient);
_chatOptions = chatOptions;
_options = Throw.IfNull(options).Clone();
if (string.IsNullOrWhiteSpace(fallbackClass))
{
fallbackClass = "Unknown";
}

_predefinedClasses = CreatePredefinedSet(predefinedClasses, fallbackClass!);
Validate(predefinedClasses, fallbackClass!);
_systemPrompt = CreateSystemPrompt(predefinedClasses, fallbackClass!);
_logger = _options.LoggerFactory?.CreateLogger<ClassificationEnricher>();
}

/// <summary>
Expand All @@ -53,28 +49,10 @@ public ClassificationEnricher(IChatClient chatClient, ReadOnlySpan<string> prede
public static string MetadataKey => "classification";

/// <inheritdoc />
public override async IAsyncEnumerable<IngestionChunk<string>> ProcessAsync(IAsyncEnumerable<IngestionChunk<string>> chunks,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(chunks);

await foreach (IngestionChunk<string> chunk in chunks.WithCancellation(cancellationToken))
{
var response = await _chatClient.GetResponseAsync(
[
_systemPrompt,
new(ChatRole.User, chunk.Content)
], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false);

chunk.Metadata[MetadataKey] = _predefinedClasses.Contains(response.Text)
? response.Text
: throw new InvalidOperationException($"Classification returned an unexpected class: '{response.Text}'.");

yield return chunk;
}
}
public override IAsyncEnumerable<IngestionChunk<string>> ProcessAsync(IAsyncEnumerable<IngestionChunk<string>> chunks, CancellationToken cancellationToken = default)
=> Batching.ProcessAsync<string>(chunks, _options, MetadataKey, _systemPrompt, _logger, cancellationToken);

private static FrozenSet<string> CreatePredefinedSet(ReadOnlySpan<string> predefinedClasses, string fallbackClass)
private static void Validate(ReadOnlySpan<string> predefinedClasses, string fallbackClass)
{
if (predefinedClasses.Length == 0)
{
Expand All @@ -84,15 +62,6 @@ private static FrozenSet<string> CreatePredefinedSet(ReadOnlySpan<string> predef
HashSet<string> predefinedClassesSet = new(StringComparer.Ordinal) { fallbackClass };
foreach (string predefinedClass in predefinedClasses)
{
#if NET
if (predefinedClass.Contains(',', StringComparison.Ordinal))
#else
if (predefinedClass.IndexOf(',') >= 0)
#endif
{
Throw.ArgumentException(nameof(predefinedClasses), $"Predefined class '{predefinedClass}' must not contain ',' character.");
}

if (!predefinedClassesSet.Add(predefinedClass))
{
if (predefinedClass.Equals(fallbackClass, StringComparison.Ordinal))
Expand All @@ -103,13 +72,11 @@ private static FrozenSet<string> CreatePredefinedSet(ReadOnlySpan<string> predef
Throw.ArgumentException(nameof(predefinedClasses), $"Duplicate class found: '{predefinedClass}'.");
}
}

return predefinedClassesSet.ToFrozenSet();
}

private static ChatMessage CreateSystemPrompt(ReadOnlySpan<string> predefinedClasses, string fallbackClass)
{
StringBuilder sb = new("You are a classification expert. Analyze the given text and assign a single, most relevant class. Use only the following predefined classes: ");
StringBuilder sb = new("You are a classification expert. For each of the following texts, assign a single, most relevant class. Use only the following predefined classes: ");

#if NET9_0_OR_GREATER
sb.AppendJoin(", ", predefinedClasses!);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.DataIngestion;

/// <summary>
/// Represents options for enrichers that use an AI chat client.
/// </summary>
public class EnricherOptions
{
/// <summary>
/// Initializes a new instance of the <see cref="EnricherOptions"/> class.
/// </summary>
/// <param name="chatClient">The AI chat client to be used.</param>
public EnricherOptions(IChatClient chatClient)
{
ChatClient = Throw.IfNull(chatClient);
}

/// <summary>
/// Gets the AI chat client to be used.
/// </summary>
public IChatClient ChatClient { get; }

/// <summary>
/// Gets or sets the options for the <see cref="ChatClient"/>.
/// </summary>
public ChatOptions? ChatOptions { get; set; }

/// <summary>
/// Gets or sets the logger factory to be used for logging.
/// </summary>
/// <remarks>
/// Enricher failures should not fail the whole ingestion pipeline, as they are best-effort enhancements.
/// This logger factory can be used to create loggers to log such failures.
/// </remarks>
public ILoggerFactory? LoggerFactory { get; set; }

/// <summary>
/// Gets or sets the batch size for processing chunks. Default is 20.
/// </summary>
public int BatchSize { get; set => field = Throw.IfLessThanOrEqual(value, 0); } = 20;

internal EnricherOptions Clone() => new(ChatClient)
{
ChatOptions = ChatOptions?.Clone(),
LoggerFactory = LoggerFactory,
BatchSize = BatchSize
};
}
Loading
Loading