Skip to content

Commit 09094ae

Browse files
EmbeddingGeneratorBuilder API updates (#5647)
1 parent aa6e8f0 commit 09094ae

File tree

15 files changed

+131
-90
lines changed

15 files changed

+131
-90
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,11 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder()
432432

433433
// Explore changing the order of the intermediate "Use" calls to see that impact
434434
// that has on what gets cached, traced, etc.
435-
IEmbeddingGenerator<string, Embedding<float>> generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>()
435+
var generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>(
436+
new SampleEmbeddingGenerator(new Uri("http://coolsite.ai"), "my-custom-model"))
436437
.UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())))
437438
.UseOpenTelemetry(sourceName)
438-
.Use(new SampleEmbeddingGenerator(new Uri("http://coolsite.ai"), "my-custom-model"));
439+
.Build();
439440

440441
var embeddings = await generator.GenerateAsync(
441442
[

src/Libraries/Microsoft.Extensions.AI.Ollama/README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,9 @@ IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDi
210210
IEmbeddingGenerator<string, Embedding<float>> ollamaGenerator =
211211
new OllamaEmbeddingGenerator(new Uri("http://localhost:11434/"), "all-minilm");
212212

213-
IEmbeddingGenerator<string, Embedding<float>> generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>()
213+
IEmbeddingGenerator<string, Embedding<float>> generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>(ollamaGenerator)
214214
.UseDistributedCache(cache)
215-
.Use(ollamaGenerator);
215+
.Build();
216216

217217
foreach (var prompt in new[] { "What is AI?", "What is .NET?", "What is AI?" })
218218
{
@@ -256,8 +256,7 @@ var builder = WebApplication.CreateBuilder(args);
256256
builder.Services.AddChatClient(
257257
new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"));
258258

259-
builder.Services.AddEmbeddingGenerator<string,Embedding<float>>(g =>
260-
g.Use(new OllamaEmbeddingGenerator(endpoint, "all-minilm")));
259+
builder.Services.AddEmbeddingGenerator(new OllamaEmbeddingGenerator(endpoint, "all-minilm"));
261260

262261
var app = builder.Build();
263262

src/Libraries/Microsoft.Extensions.AI.OpenAI/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,9 @@ IEmbeddingGenerator<string, Embedding<float>> openAIGenerator =
233233
new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))
234234
.AsEmbeddingGenerator("text-embedding-3-small");
235235

236-
IEmbeddingGenerator<string, Embedding<float>> generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>()
236+
IEmbeddingGenerator<string, Embedding<float>> generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>(openAIGenerator)
237237
.UseDistributedCache(cache)
238-
.Use(openAIGenerator);
238+
.Build();
239239

240240
foreach (var prompt in new[] { "What is AI?", "What is .NET?", "What is AI?" })
241241
{
@@ -284,8 +284,8 @@ builder.Services.AddSingleton(new OpenAIClient(builder.Configuration["OPENAI_API
284284
builder.Services.AddChatClient(services =>
285285
services.GetRequiredService<OpenAIClient>().AsChatClient("gpt-4o-mini"));
286286

287-
builder.Services.AddEmbeddingGenerator<string, Embedding<float>>(g =>
288-
g.Use(g.Services.GetRequiredService<OpenAIClient>().AsEmbeddingGenerator("text-embedding-3-small")));
287+
builder.Services.AddEmbeddingGenerator(services =>
288+
services.GetRequiredService<OpenAIClient>().AsEmbeddingGenerator("text-embedding-3-small"));
289289

290290
var app = builder.Build();
291291

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace Microsoft.Extensions.AI;
1010
/// <summary>A builder for creating pipelines of <see cref="IChatClient"/>.</summary>
1111
public sealed class ChatClientBuilder
1212
{
13-
private Func<IServiceProvider, IChatClient> _innerClientFactory;
13+
private readonly Func<IServiceProvider, IChatClient> _innerClientFactory;
1414

1515
/// <summary>The registered client factory instances.</summary>
1616
private List<Func<IServiceProvider, IChatClient, IChatClient>>? _clientFactories;
@@ -30,7 +30,7 @@ public ChatClientBuilder(Func<IServiceProvider, IChatClient> innerClientFactory)
3030
_innerClientFactory = Throw.IfNull(innerClientFactory);
3131
}
3232

33-
/// <summary>Returns an <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</summary>
33+
/// <summary>Builds an <see cref="IChatClient"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.</summary>
3434
/// <param name="services">
3535
/// The <see cref="IServiceProvider"/> that should provide services to the <see cref="IChatClient"/> instances.
3636
/// If null, an empty <see cref="IServiceProvider"/> will be used.

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public static ChatClientBuilder AddChatClient(
3737
return builder;
3838
}
3939

40-
/// <summary>Registers a singleton <see cref="IChatClient"/> in the <see cref="IServiceCollection"/>.</summary>
40+
/// <summary>Registers a keyed singleton <see cref="IChatClient"/> in the <see cref="IServiceCollection"/>.</summary>
4141
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
4242
/// <param name="serviceKey">The key with which to associate the client.</param>
4343
/// <param name="innerClient">The inner <see cref="IChatClient"/> that represents the underlying backend.</param>
@@ -49,7 +49,7 @@ public static ChatClientBuilder AddKeyedChatClient(
4949
IChatClient innerClient)
5050
=> AddKeyedChatClient(serviceCollection, serviceKey, _ => innerClient);
5151

52-
/// <summary>Registers a singleton <see cref="IChatClient"/> in the <see cref="IServiceCollection"/>.</summary>
52+
/// <summary>Registers a keyed singleton <see cref="IChatClient"/> in the <see cref="IServiceCollection"/>.</summary>
5353
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the client should be added.</param>
5454
/// <param name="serviceKey">The key with which to associate the client.</param>
5555
/// <param name="innerClientFactory">A callback that produces the inner <see cref="IChatClient"/> that represents the underlying backend.</param>

src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,39 +13,45 @@ namespace Microsoft.Extensions.AI;
1313
public sealed class EmbeddingGeneratorBuilder<TInput, TEmbedding>
1414
where TEmbedding : Embedding
1515
{
16+
private readonly Func<IServiceProvider, IEmbeddingGenerator<TInput, TEmbedding>> _innerGeneratorFactory;
17+
1618
/// <summary>The registered client factory instances.</summary>
1719
private List<Func<IServiceProvider, IEmbeddingGenerator<TInput, TEmbedding>, IEmbeddingGenerator<TInput, TEmbedding>>>? _generatorFactories;
1820

1921
/// <summary>Initializes a new instance of the <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> class.</summary>
20-
/// <param name="services">The service provider to use for dependency injection.</param>
21-
public EmbeddingGeneratorBuilder(IServiceProvider? services = null)
22+
/// <param name="innerGenerator">The inner <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> that represents the underlying backend.</param>
23+
public EmbeddingGeneratorBuilder(IEmbeddingGenerator<TInput, TEmbedding> innerGenerator)
2224
{
23-
Services = services ?? EmptyServiceProvider.Instance;
25+
_ = Throw.IfNull(innerGenerator);
26+
_innerGeneratorFactory = _ => innerGenerator;
2427
}
2528

26-
/// <summary>Gets the <see cref="IServiceProvider"/> associated with the builder instance.</summary>
27-
public IServiceProvider Services { get; }
29+
/// <summary>Initializes a new instance of the <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> class.</summary>
30+
/// <param name="innerGeneratorFactory">A callback that produces the inner <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> that represents the underlying backend.</param>
31+
public EmbeddingGeneratorBuilder(Func<IServiceProvider, IEmbeddingGenerator<TInput, TEmbedding>> innerGeneratorFactory)
32+
{
33+
_innerGeneratorFactory = Throw.IfNull(innerGeneratorFactory);
34+
}
2835

2936
/// <summary>
30-
/// Builds an instance of <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> using the specified inner generator.
37+
/// Builds an <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn.
3138
/// </summary>
32-
/// <param name="innerGenerator">The inner generator to use.</param>
33-
/// <returns>An instance of <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>.</returns>
34-
/// <remarks>
35-
/// If there are any factories registered with this builder, <paramref name="innerGenerator"/> is used as a seed to
36-
/// the last factory, and the result of each factory delegate is passed to the previously registered factory.
37-
/// The final result is then returned from this call.
38-
/// </remarks>
39-
public IEmbeddingGenerator<TInput, TEmbedding> Use(IEmbeddingGenerator<TInput, TEmbedding> innerGenerator)
39+
/// <param name="services">
40+
/// The <see cref="IServiceProvider"/> that should provide services to the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> instances.
41+
/// If null, an empty <see cref="IServiceProvider"/> will be used.
42+
/// </param>
43+
/// <returns>An instance of <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> that represents the entire pipeline.</returns>
44+
public IEmbeddingGenerator<TInput, TEmbedding> Build(IServiceProvider? services = null)
4045
{
41-
var embeddingGenerator = Throw.IfNull(innerGenerator);
46+
services ??= EmptyServiceProvider.Instance;
47+
var embeddingGenerator = _innerGeneratorFactory(services);
4248

4349
// To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost.
4450
if (_generatorFactories is not null)
4551
{
4652
for (var i = _generatorFactories.Count - 1; i >= 0; i--)
4753
{
48-
embeddingGenerator = _generatorFactories[i](Services, embeddingGenerator) ??
54+
embeddingGenerator = _generatorFactories[i](services, embeddingGenerator) ??
4955
throw new InvalidOperationException(
5056
$"The {nameof(IEmbeddingGenerator<TInput, TEmbedding>)} entry at index {i} returned null. " +
5157
$"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IEmbeddingGenerator<TInput, TEmbedding>)} instances.");

src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,44 +10,74 @@ namespace Microsoft.Extensions.DependencyInjection;
1010
/// <summary>Provides extension methods for registering <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> with a <see cref="IServiceCollection"/>.</summary>
1111
public static class EmbeddingGeneratorBuilderServiceCollectionExtensions
1212
{
13-
/// <summary>Adds a embedding generator to the <see cref="IServiceCollection"/>.</summary>
13+
/// <summary>Registers a singleton embedding generator in the <see cref="IServiceCollection"/>.</summary>
1414
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
1515
/// <typeparam name="TEmbedding">The type of embeddings to generate.</typeparam>
16-
/// <param name="services">The <see cref="IServiceCollection"/> to which the generator should be added.</param>
17-
/// <param name="generatorFactory">The factory to use to construct the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> instance.</param>
18-
/// <returns>The <paramref name="services"/> collection.</returns>
19-
/// <remarks>The generator is registered as a scoped service.</remarks>
20-
public static IServiceCollection AddEmbeddingGenerator<TInput, TEmbedding>(
21-
this IServiceCollection services,
22-
Func<EmbeddingGeneratorBuilder<TInput, TEmbedding>, IEmbeddingGenerator<TInput, TEmbedding>> generatorFactory)
16+
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the generator should be added.</param>
17+
/// <param name="innerGenerator">The inner <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> that represents the underlying backend.</param>
18+
/// <returns>An <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> that can be used to build a pipeline around the inner generator.</returns>
19+
/// <remarks>The generator is registered as a singleton service.</remarks>
20+
public static EmbeddingGeneratorBuilder<TInput, TEmbedding> AddEmbeddingGenerator<TInput, TEmbedding>(
21+
this IServiceCollection serviceCollection,
22+
IEmbeddingGenerator<TInput, TEmbedding> innerGenerator)
23+
where TEmbedding : Embedding
24+
=> AddEmbeddingGenerator(serviceCollection, _ => innerGenerator);
25+
26+
/// <summary>Registers a singleton embedding generator in the <see cref="IServiceCollection"/>.</summary>
27+
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
28+
/// <typeparam name="TEmbedding">The type of embeddings to generate.</typeparam>
29+
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the generator should be added.</param>
30+
/// <param name="innerGeneratorFactory">A callback that produces the inner <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> that represents the underlying backend.</param>
31+
/// <returns>An <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> that can be used to build a pipeline around the inner generator.</returns>
32+
/// <remarks>The generator is registered as a singleton service.</remarks>
33+
public static EmbeddingGeneratorBuilder<TInput, TEmbedding> AddEmbeddingGenerator<TInput, TEmbedding>(
34+
this IServiceCollection serviceCollection,
35+
Func<IServiceProvider, IEmbeddingGenerator<TInput, TEmbedding>> innerGeneratorFactory)
2336
where TEmbedding : Embedding
2437
{
25-
_ = Throw.IfNull(services);
26-
_ = Throw.IfNull(generatorFactory);
38+
_ = Throw.IfNull(serviceCollection);
39+
_ = Throw.IfNull(innerGeneratorFactory);
2740

28-
return services.AddScoped(services =>
29-
generatorFactory(new EmbeddingGeneratorBuilder<TInput, TEmbedding>(services)));
41+
var builder = new EmbeddingGeneratorBuilder<TInput, TEmbedding>(innerGeneratorFactory);
42+
_ = serviceCollection.AddSingleton(builder.Build);
43+
return builder;
3044
}
3145

32-
/// <summary>Adds an embedding generator to the <see cref="IServiceCollection"/>.</summary>
46+
/// <summary>Registers a keyed singleton embedding generator in the <see cref="IServiceCollection"/>.</summary>
47+
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
48+
/// <typeparam name="TEmbedding">The type of embeddings to generate.</typeparam>
49+
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the generator should be added.</param>
50+
/// <param name="serviceKey">The key with which to associated the generator.</param>
51+
/// <param name="innerGenerator">The inner <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> that represents the underlying backend.</param>
52+
/// <returns>An <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> that can be used to build a pipeline around the inner generator.</returns>
53+
/// <remarks>The generator is registered as a singleton service.</remarks>
54+
public static EmbeddingGeneratorBuilder<TInput, TEmbedding> AddKeyedEmbeddingGenerator<TInput, TEmbedding>(
55+
this IServiceCollection serviceCollection,
56+
object serviceKey,
57+
IEmbeddingGenerator<TInput, TEmbedding> innerGenerator)
58+
where TEmbedding : Embedding
59+
=> AddKeyedEmbeddingGenerator(serviceCollection, serviceKey, _ => innerGenerator);
60+
61+
/// <summary>Registers a keyed singleton embedding generator in the <see cref="IServiceCollection"/>.</summary>
3362
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
3463
/// <typeparam name="TEmbedding">The type of embeddings to generate.</typeparam>
35-
/// <param name="services">The <see cref="IServiceCollection"/> to which the service should be added.</param>
64+
/// <param name="serviceCollection">The <see cref="IServiceCollection"/> to which the generator should be added.</param>
3665
/// <param name="serviceKey">The key with which to associated the generator.</param>
37-
/// <param name="generatorFactory">The factory to use to construct the <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> instance.</param>
38-
/// <returns>The <paramref name="services"/> collection.</returns>
39-
/// <remarks>The generator is registered as a scoped service.</remarks>
40-
public static IServiceCollection AddKeyedEmbeddingGenerator<TInput, TEmbedding>(
41-
this IServiceCollection services,
66+
/// <param name="innerGeneratorFactory">A callback that produces the inner <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/> that represents the underlying backend.</param>
67+
/// <returns>An <see cref="EmbeddingGeneratorBuilder{TInput, TEmbedding}"/> that can be used to build a pipeline around the inner generator.</returns>
68+
/// <remarks>The generator is registered as a singleton service.</remarks>
69+
public static EmbeddingGeneratorBuilder<TInput, TEmbedding> AddKeyedEmbeddingGenerator<TInput, TEmbedding>(
70+
this IServiceCollection serviceCollection,
4271
object serviceKey,
43-
Func<EmbeddingGeneratorBuilder<TInput, TEmbedding>, IEmbeddingGenerator<TInput, TEmbedding>> generatorFactory)
72+
Func<IServiceProvider, IEmbeddingGenerator<TInput, TEmbedding>> innerGeneratorFactory)
4473
where TEmbedding : Embedding
4574
{
46-
_ = Throw.IfNull(services);
75+
_ = Throw.IfNull(serviceCollection);
4776
_ = Throw.IfNull(serviceKey);
48-
_ = Throw.IfNull(generatorFactory);
77+
_ = Throw.IfNull(innerGeneratorFactory);
4978

50-
return services.AddKeyedScoped(serviceKey, (services, _) =>
51-
generatorFactory(new EmbeddingGeneratorBuilder<TInput, TEmbedding>(services)));
79+
var builder = new EmbeddingGeneratorBuilder<TInput, TEmbedding>(innerGeneratorFactory);
80+
_ = serviceCollection.AddKeyedSingleton(serviceKey, (services, _) => builder.Build(services));
81+
return builder;
5282
}
5383
}

test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceEmbeddingGeneratorTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ public void GetService_SuccessfullyReturnsUnderlyingClient()
6363
Assert.Same(embeddingGenerator, embeddingGenerator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
6464
Assert.Same(client, embeddingGenerator.GetService<EmbeddingsClient>());
6565

66-
using IEmbeddingGenerator<string, Embedding<float>> pipeline = new EmbeddingGeneratorBuilder<string, Embedding<float>>()
66+
using IEmbeddingGenerator<string, Embedding<float>> pipeline = new EmbeddingGeneratorBuilder<string, Embedding<float>>(embeddingGenerator)
6767
.UseOpenTelemetry()
6868
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
69-
.Use(embeddingGenerator);
69+
.Build();
7070

7171
Assert.NotNull(pipeline.GetService<DistributedCachingEmbeddingGenerator<string, Embedding<float>>>());
7272
Assert.NotNull(pipeline.GetService<CachingEmbeddingGenerator<string, Embedding<float>>>());

test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ public virtual async Task Caching_SameOutputsForSameInput()
8181
{
8282
SkipIfNotEnabled();
8383

84-
using var generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>()
84+
using var generator = new EmbeddingGeneratorBuilder<string, Embedding<float>>(CreateEmbeddingGenerator()!)
8585
.UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
8686
.UseCallCounting()
87-
.Use(CreateEmbeddingGenerator()!);
87+
.Build();
8888

8989
string input = "Red, White, and Blue";
9090
var embedding1 = await generator.GenerateEmbeddingAsync(input);
@@ -110,9 +110,9 @@ public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics()
110110
.AddInMemoryExporter(activities)
111111
.Build();
112112

113-
var embeddingGenerator = new EmbeddingGeneratorBuilder<string, Embedding<float>>()
113+
var embeddingGenerator = new EmbeddingGeneratorBuilder<string, Embedding<float>>(CreateEmbeddingGenerator()!)
114114
.UseOpenTelemetry(sourceName: sourceName)
115-
.Use(CreateEmbeddingGenerator()!);
115+
.Build();
116116

117117
_ = await embeddingGenerator.GenerateEmbeddingAsync("Hello, world!");
118118

0 commit comments

Comments
 (0)