.Net: Add GenAI support to Connectors.Onnx (microsoft#6518)
@stephentoub invited me to contribute https://github.com/feiyun0112/SemanticKernel.Connectors.OnnxRuntimeGenAI to Microsoft.SemanticKernel.Connectors.Onnx.


https://github.com/feiyun0112/SemanticKernel.Connectors.OnnxRuntimeGenAI/issues/4

---------

Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com>
Co-authored-by: Stephen Toub <stoub@microsoft.com>
Co-authored-by: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
4 people authored Jul 23, 2024
1 parent f703783 commit ffeb099
Showing 8 changed files with 517 additions and 0 deletions.
4 changes: 4 additions & 0 deletions dotnet/Directory.Packages.props
@@ -130,5 +130,9 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<!-- OnnxRuntimeGenAI -->
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI" Version="0.3.0"/>
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Version="0.3.0"/>
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Version="0.3.0"/>
</ItemGroup>
</Project>
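These PackageVersion entries pin the GenAI package versions centrally; the repository uses NuGet central package management, which is why the PackageReference items in the connector project below carry no Version attribute.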
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Onnx;
using Xunit;

namespace SemanticKernel.Connectors.Onnx.UnitTests;

/// <summary>
/// Unit tests for <see cref="OnnxKernelBuilderExtensions"/>.
/// </summary>
public class OnnxExtensionsTests
{
[Fact]
public void AddOnnxRuntimeGenAIChatCompletionToServiceCollection()
{
// Arrange
var collection = new ServiceCollection();
collection.AddOnnxRuntimeGenAIChatCompletion("modelId", "modelPath");

// Act
collection.AddKernel();
var kernel = collection.BuildServiceProvider().GetRequiredService<Kernel>();
var service = kernel.GetRequiredService<IChatCompletionService>();

// Assert
Assert.NotNull(service);
Assert.IsType<OnnxRuntimeGenAIChatCompletionService>(service);
}

[Fact]
public void AddOnnxRuntimeGenAIChatCompletionToKernelBuilder()
{
// Arrange
var collection = new ServiceCollection();
var kernelBuilder = collection.AddKernel();
kernelBuilder.AddOnnxRuntimeGenAIChatCompletion("modelId", "modelPath");

// Act
var kernel = collection.BuildServiceProvider().GetRequiredService<Kernel>();
var service = kernel.GetRequiredService<IChatCompletionService>();

// Assert
Assert.NotNull(service);
Assert.IsType<OnnxRuntimeGenAIChatCompletionService>(service);
}
}
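Outside the test project, the same registration supports an optional service key. A minimal sketch, assuming the serviceId parameter of the extension method added later in this commit; the key "my-onnx" and the model values are placeholders:

using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

var builder = Kernel.CreateBuilder();
builder.AddOnnxRuntimeGenAIChatCompletion("modelId", "modelPath", serviceId: "my-onnx");
var kernel = builder.Build();

// Kernel.GetRequiredService<T> accepts an optional service key for keyed registrations.
var service = kernel.GetRequiredService<IChatCompletionService>("my-onnx");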
@@ -0,0 +1,97 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.Onnx;
using Xunit;

namespace SemanticKernel.Connectors.Onnx.UnitTests;

/// <summary>
/// Unit tests for <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/>.
/// </summary>
public class OnnxRuntimeGenAIPromptExecutionSettingsTests
{
[Fact]
public void FromExecutionSettingsWhenAlreadyOnnxShouldReturnSame()
{
// Arrange
var executionSettings = new OnnxRuntimeGenAIPromptExecutionSettings();

// Act
var onnxExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

// Assert
Assert.Same(executionSettings, onnxExecutionSettings);
}

[Fact]
public void FromExecutionSettingsWhenNullShouldReturnDefaultSettings()
{
// Arrange
PromptExecutionSettings? executionSettings = null;

// Act
var onnxExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

// Assert
Assert.Null(onnxExecutionSettings.TopK);
Assert.Null(onnxExecutionSettings.TopP);
Assert.Null(onnxExecutionSettings.Temperature);
Assert.Null(onnxExecutionSettings.RepetitionPenalty);
Assert.Null(onnxExecutionSettings.PastPresentShareBuffer);
Assert.Null(onnxExecutionSettings.NumReturnSequences);
Assert.Null(onnxExecutionSettings.NumBeams);
Assert.Null(onnxExecutionSettings.NoRepeatNgramSize);
Assert.Null(onnxExecutionSettings.MinTokens);
Assert.Null(onnxExecutionSettings.MaxTokens);
Assert.Null(onnxExecutionSettings.LengthPenalty);
Assert.Null(onnxExecutionSettings.DiversityPenalty);
Assert.Null(onnxExecutionSettings.EarlyStopping);
Assert.Null(onnxExecutionSettings.DoSample);
}

[Fact]
public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecialized()
{
// Arrange
string jsonSettings = """
{
"top_k": 2,
"top_p": 0.9,
"temperature": 0.5,
"repetition_penalty": 0.1,
"past_present_share_buffer": true,
"num_return_sequences": 200,
"num_beams": 20,
"no_repeat_ngram_size": 15,
"min_tokens": 10,
"max_tokens": 100,
"length_penalty": 0.2,
"diversity_penalty": 0.3,
"early_stopping": false,
"do_sample": true
}
""";

// Act
var executionSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(jsonSettings);
var onnxExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

// Assert
Assert.Equal(2, onnxExecutionSettings.TopK);
Assert.Equal(0.9f, onnxExecutionSettings.TopP);
Assert.Equal(0.5f, onnxExecutionSettings.Temperature);
Assert.Equal(0.1f, onnxExecutionSettings.RepetitionPenalty);
Assert.True(onnxExecutionSettings.PastPresentShareBuffer);
Assert.Equal(200, onnxExecutionSettings.NumReturnSequences);
Assert.Equal(20, onnxExecutionSettings.NumBeams);
Assert.Equal(15, onnxExecutionSettings.NoRepeatNgramSize);
Assert.Equal(10, onnxExecutionSettings.MinTokens);
Assert.Equal(100, onnxExecutionSettings.MaxTokens);
Assert.Equal(0.2f, onnxExecutionSettings.LengthPenalty);
Assert.Equal(0.3f, onnxExecutionSettings.DiversityPenalty);
Assert.False(onnxExecutionSettings.EarlyStopping);
Assert.True(onnxExecutionSettings.DoSample);
}
}
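For comparison with the JSON round-trip above, the specialized settings can also be built directly. A minimal sketch using only properties exercised by these tests (the values are arbitrary):

using Microsoft.SemanticKernel.Connectors.Onnx;

// Strongly typed equivalent of a subset of the JSON settings above.
var settings = new OnnxRuntimeGenAIPromptExecutionSettings
{
    TopK = 2,
    TopP = 0.9f,
    Temperature = 0.5f,
    MaxTokens = 100,
    DoSample = true,
};

// Passing the specialized type through FromExecutionSettings returns the same instance,
// as asserted by FromExecutionSettingsWhenAlreadyOnnxShouldReturnSame.
var same = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(settings);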
6 changes: 6 additions & 0 deletions dotnet/src/Connectors/Connectors.Onnx/Connectors.Onnx.csproj
@@ -25,4 +25,10 @@
<PackageReference Include="System.Numerics.Tensors" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI" Condition=" '$(Configuration)' == 'Debug' OR '$(Configuration)' == 'Release' " />
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Condition=" '$(Configuration)' == 'Debug_Cuda' OR '$(Configuration)' == 'Release_Cuda' " />
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Condition=" '$(Configuration)' == 'Debug_DirectML' OR '$(Configuration)' == 'Release_DirectML' " />
</ItemGroup>

</Project>
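The hardware-specific GenAI packages are chosen at build time rather than at run time: the CPU package ships with the ordinary Debug/Release configurations, while CUDA and DirectML require the dedicated configurations named in the conditions above. Assuming those configurations are defined in the solution, a CUDA-enabled build would be produced with, for example, dotnet build -c Release_Cuda.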
@@ -2,6 +2,8 @@

using System.IO;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Onnx;
using Microsoft.SemanticKernel.Embeddings;

@@ -14,6 +16,29 @@ namespace Microsoft.SemanticKernel;
/// </summary>
public static class OnnxKernelBuilderExtensions
{
/// <summary>
/// Add OnnxRuntimeGenAI Chat Completion services to the kernel builder.
/// </summary>
/// <param name="builder">The kernel builder.</param>
/// <param name="modelId">Model Id.</param>
/// <param name="modelPath">The generative AI ONNX model path.</param>
/// <param name="serviceId">The optional service ID.</param>
/// <returns>The updated kernel builder.</returns>
public static IKernelBuilder AddOnnxRuntimeGenAIChatCompletion(
this IKernelBuilder builder,
string modelId,
string modelPath,
string? serviceId = null)
{
builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
new OnnxRuntimeGenAIChatCompletionService(
modelId,
modelPath: modelPath,
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));

return builder;
}

/// <summary>Adds a text embedding generation service using a BERT ONNX model.</summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="onnxModelPath">The path to the ONNX model file.</param>
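Taken together, the new registration path makes an end-to-end chat call look roughly like the following sketch; the model id, model path, and prompt are placeholders, not code from this commit:

using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

var builder = Kernel.CreateBuilder();
// Point the connector at a local generative-AI ONNX model directory, e.g. a Phi-3 build.
builder.AddOnnxRuntimeGenAIChatCompletion(modelId: "phi-3", modelPath: @"C:\models\phi-3");
var kernel = builder.Build();

var chat = kernel.GetRequiredService<IChatCompletionService>();
var history = new ChatHistory();
history.AddUserMessage("Hello!");

// GetChatMessageContentAsync is the standard single-response helper for the interface.
var reply = await chat.GetChatMessageContentAsync(history, kernel: kernel);
Console.WriteLine(reply.Content);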
