[GenAI] Introduce CausalLMPipelineChatClient for MEAI.IChatClient #7270

Merged · 7 commits · Nov 5, 2024

Changes from all commits
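This PR adds an abstract `CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>` that adapts a `CausalLMPipeline` to the `Microsoft.Extensions.AI.IChatClient` abstraction, a new `IMEAIChatTemplateBuilder` interface for turning MEAI `ChatMessage` lists into model prompts, and concrete clients such as `Llama3CausalLMChatClient` and `Phi3CausalLMChatClient`. A minimal usage sketch, assuming a `pipeline` has already been constructed as in the samples below:

```csharp
// Minimal sketch, assuming `pipeline` is an ICausalLMPipeline built as in the
// samples below. CompleteAsync returns a single ChatCompletion;
// CompleteStreamingAsync yields incremental updates.
IChatClient client = new Llama3CausalLMChatClient(pipeline);

var completion = await client.CompleteAsync(
    [new ChatMessage(ChatRole.User, "Hello!")],
    new ChatOptions { MaxOutputTokens = 256, Temperature = 0.7f });

Console.WriteLine(completion.Message.Text);
```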
54 changes: 54 additions & 0 deletions docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Llama3_1.cs
@@ -0,0 +1,54 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Samples.MEAI;

internal class Llama3_1
{
public static async Task RunAsync(string weightFolder, string checkPointName = "model.safetensors.index.json")
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}

var defaultType = ScalarType.BFloat16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var configName = "config.json";
var originalWeightFolder = Path.Combine(weightFolder, "original");

Console.WriteLine("Loading Llama from huggingface model weight folder");
var stopWatch = System.Diagnostics.Stopwatch.StartNew();
var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true);

var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);

var client = new Llama3CausalLMChatClient(pipeline);

var task = """
Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
""";
var chatMessage = new ChatMessage(ChatRole.User, task);

await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
{
Console.Write(response.Text);
}
}
}
44 changes: 44 additions & 0 deletions docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Phi3.cs
@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static TorchSharp.torch;
using TorchSharp;
using Microsoft.ML.GenAI.Phi;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
using Microsoft.Extensions.AI;

namespace Microsoft.ML.GenAI.Samples.MEAI;

internal class Phi3
{
public static async Task RunAsync(string weightFolder)
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}

var defaultType = ScalarType.Float16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
var client = new Phi3CausalLMChatClient(pipeline);

var task = """
Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
""";
var chatMessage = new ChatMessage(ChatRole.User, task);

await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
{
Console.Write(response.Text);
}
}
}
4 changes: 3 additions & 1 deletion docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
@@ -1,4 +1,6 @@
// See https://aka.ms/new-console-template for more information
using Microsoft.ML.GenAI.Samples.Llama;
using Microsoft.ML.GenAI.Samples.MEAI;

await LlamaSample.RunLlama(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-3B-Instruct");
//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
3 changes: 2 additions & 1 deletion eng/Versions.props
@@ -34,7 +34,7 @@
<SystemRuntimeCompilerServicesUnsafeVersion>6.0.0</SystemRuntimeCompilerServicesUnsafeVersion>
<SystemSecurityPrincipalWindows>5.0.0</SystemSecurityPrincipalWindows>
<SystemTextEncodingsWebVersion>8.0.0</SystemTextEncodingsWebVersion>
- <SystemTextJsonVersion>8.0.4</SystemTextJsonVersion>
+ <SystemTextJsonVersion>8.0.5</SystemTextJsonVersion>
<SystemThreadingChannelsVersion>8.0.0</SystemThreadingChannelsVersion>
<!-- Other product dependencies -->
<ApacheArrowVersion>14.0.2</ApacheArrowVersion>
@@ -47,6 +47,7 @@
<MicrosoftDotNetInteractiveVersion>1.0.0-beta.24375.2</MicrosoftDotNetInteractiveVersion>
<MicrosoftMLOnnxRuntimeVersion>1.18.1</MicrosoftMLOnnxRuntimeVersion>
<MlNetMklDepsVersion>0.0.0.12</MlNetMklDepsVersion>
<MicrosoftExtensionsAIVersion>9.0.0-preview.9.24507.7</MicrosoftExtensionsAIVersion>
<!--
@("inteltbb.devel", "win", "2021.7.1.15305")
-->
89 changes: 89 additions & 0 deletions src/Microsoft.ML.GenAI.Core/CausalLMPipelineChatClient.cs
@@ -0,0 +1,89 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.Tokenizers;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core;

public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : IChatClient
where TTokenizer : Tokenizer
where TCausalLMModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
private readonly ICausalLMPipeline<TTokenizer, TCausalLMModel> _pipeline;
private readonly IMEAIChatTemplateBuilder _chatTemplateBuilder;

public CausalLMPipelineChatClient(
ICausalLMPipeline<TTokenizer, TCausalLMModel> pipeline,
IMEAIChatTemplateBuilder chatTemplateBuilder,
ChatClientMetadata? metadata = null)
{
var classNameWithType = $"{nameof(CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>)}<{typeof(TTokenizer).Name}, {typeof(TCausalLMModel).Name}>";
Metadata = metadata ?? new ChatClientMetadata(providerName: classNameWithType, modelId: typeof(TCausalLMModel).Name);
_chatTemplateBuilder = chatTemplateBuilder;
_pipeline = pipeline;
}

public ChatClientMetadata Metadata { get; }

public virtual Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
var stopSequences = options?.StopSequences ?? Array.Empty<string>();

var output = _pipeline.Generate(
prompt,
maxLen: options?.MaxOutputTokens ?? 1024,
temperature: options?.Temperature ?? 0.7f,
stopSequences: stopSequences.ToArray()) ?? throw new InvalidOperationException("Failed to generate a reply.");

var chatMessage = new ChatMessage(ChatRole.Assistant, output);
return Task.FromResult(new ChatCompletion([chatMessage])
{
CreatedAt = DateTime.UtcNow,
FinishReason = ChatFinishReason.Stop,
});
}

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public virtual async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
var stopSequences = options?.StopSequences ?? Array.Empty<string>();

foreach (var output in _pipeline.GenerateStreaming(
prompt,
maxLen: options?.MaxOutputTokens ?? 1024,
temperature: options?.Temperature ?? 0.7f,
stopSequences: stopSequences.ToArray()))
{
yield return new StreamingChatCompletionUpdate
{
Role = ChatRole.Assistant,
Text = output,
CreatedAt = DateTime.UtcNow,
};
}
}

public virtual void Dispose()
{
}

public virtual TService? GetService<TService>(object? key = null) where TService : class
{
return null;
}
}
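A model-specific client only needs to hand this base class a prompt builder (and optional metadata); the base class then handles prompt construction, generation, and streaming. A hypothetical minimal subclass, mirroring the Llama client added below (`MyChatClient` is an illustrative name, not part of this PR):

```csharp
// Hypothetical subclass sketch: supply an IMEAIChatTemplateBuilder to the base
// class and inherit the CompleteAsync/CompleteStreamingAsync implementations.
public sealed class MyChatClient : CausalLMPipelineChatClient<Tokenizer, LlamaForCausalLM>
{
    public MyChatClient(ICausalLMPipeline<Tokenizer, LlamaForCausalLM> pipeline)
        : base(pipeline, Llama3_1ChatTemplateBuilder.Instance)
    {
    }
}
```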
1 change: 1 addition & 0 deletions src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
@@ -13,6 +13,7 @@

<ItemGroup>
<PackageReference Include="AutoGen.Core" Version="$(AutoGenVersion)" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsAIVersion)" />
<PackageReference Include="Microsoft.SemanticKernel.Abstractions" Version="$(SemanticKernelVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
6 changes: 6 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
@@ -8,6 +8,7 @@
using System.Text;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.ML.GenAI.Core;
@@ -22,6 +23,11 @@ public interface IAutoGenChatTemplateBuilder
string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null);
}

public interface IMEAIChatTemplateBuilder
{
string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null);
}

public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
{
}
57 changes: 57 additions & 0 deletions src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs
@@ -0,0 +1,57 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;

namespace Microsoft.ML.GenAI.LLaMA;

public class Llama3CausalLMChatClient : CausalLMPipelineChatClient<Tokenizer, LlamaForCausalLM>
{
private readonly string _eotToken = "<|eot_id|>";

public Llama3CausalLMChatClient(
ICausalLMPipeline<Tokenizer, LlamaForCausalLM> pipeline,
IMEAIChatTemplateBuilder? chatTemplateBuilder = null,
ChatClientMetadata? metadata = null)
: base(
pipeline,
chatTemplateBuilder ?? Llama3_1ChatTemplateBuilder.Instance,
metadata ?? new ChatClientMetadata(modelId: nameof(Llama3CausalLMChatClient)))
{
}

public override Task<ChatCompletion> CompleteAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();

if (options.StopSequences != null)
{
options.StopSequences.Add(_eotToken);
}
else
{
options.StopSequences = new List<string> { _eotToken };
}

return base.CompleteAsync(chatMessages, options, cancellationToken);
}

public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();
options.StopSequences ??= [];
options.StopSequences.Add(_eotToken);

return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
}
}
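Both overrides make sure the Llama 3.1 end-of-turn token `<|eot_id|>` is always among the stop sequences while preserving anything the caller supplied, so callers never need to know the model's special tokens. A usage sketch (the caller-supplied stop sequence here is arbitrary):

```csharp
// The caller's "\n\n" stop sequence is kept; the client appends "<|eot_id|>".
var options = new ChatOptions { StopSequences = new List<string> { "\n\n" } };
await foreach (var update in client.CompleteStreamingAsync(chatMessages, options))
{
    Console.Write(update.Text);
}
```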
38 changes: 37 additions & 1 deletion src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
@@ -4,13 +4,15 @@

using System.Text;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using TextContent = Microsoft.SemanticKernel.TextContent;

namespace Microsoft.ML.GenAI.LLaMA;
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBuilder
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
private const char Newline = '\n';
@@ -86,5 +88,39 @@ public string BuildPrompt(ChatHistory chatHistory)
return sb.ToString();
}

public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
{
var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
if (messages.Any(m => m.Text is null))
{
throw new InvalidOperationException("Please provide a message with content.");
}

if (messages.Any(m => availableRoles.Any(availableRole => availableRole == m.Role) == false))
{
throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
}

var sb = new StringBuilder();
sb.Append("<|begin_of_text|>");
foreach (var message in messages)
{
var content = message.Text!;
sb.Append(message switch
{
_ when message.Role == ChatRole.System => $"<|start_header_id|>system<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
_ when message.Role == ChatRole.User => $"<|start_header_id|>user<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
_ when message.Role == ChatRole.Assistant => $"<|start_header_id|>assistant<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
_ => throw new InvalidOperationException("Invalid role.")
});
}

sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
return sb.ToString();
}

public static Llama3_1ChatTemplateBuilder Instance { get; } = new Llama3_1ChatTemplateBuilder();
}
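For reference, given the template code above, a system-plus-user conversation renders as follows (message text is illustrative; the final assistant header is left open so generation continues from there):

```text
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Hi!<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
```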