Skip to content

Commit

Permalink
[GenAI] Introduce CausalLMPipelineChatClient for MEAI.IChatClient (do…
Browse files Browse the repository at this point in the history
…tnet#7270)

* leverage MEAI abstraction

* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* Update src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatClient.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>

* fix comments

* Update Microsoft.ML.GenAI.Core.csproj

---------

Co-authored-by: Stephen Toub <stoub@microsoft.com>
  • Loading branch information
2 people authored and michaelgsharp committed Nov 11, 2024
1 parent 2a879f6 commit 6ed398e
Show file tree
Hide file tree
Showing 16 changed files with 534 additions and 82 deletions.
54 changes: 54 additions & 0 deletions docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Llama3_1.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Samples.MEAI;

internal class Llama3_1
{
public static async Task RunAsync(string weightFolder, string checkPointName = "model.safetensors.index.json")
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}

var defaultType = ScalarType.BFloat16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var configName = "config.json";
var originalWeightFolder = Path.Combine(weightFolder, "original");

Console.WriteLine("Loading Llama from huggingface model weight folder");
var stopWatch = System.Diagnostics.Stopwatch.StartNew();
stopWatch.Start();
var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true);

var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);

var client = new Llama3CausalLMChatClient(pipeline);

var task = """
Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
""";
var chatMessage = new ChatMessage(ChatRole.User, task);

await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
{
Console.Write(response.Text);
}
}
}
44 changes: 44 additions & 0 deletions docs/samples/Microsoft.ML.GenAI.Samples/MEAI/Phi3.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static TorchSharp.torch;
using TorchSharp;
using Microsoft.ML.GenAI.Phi;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
using Microsoft.Extensions.AI;

namespace Microsoft.ML.GenAI.Samples.MEAI;

internal class Phi3
{
public static async Task RunAsync(string weightFolder)
{
var device = "cuda";
if (device == "cuda")
{
torch.InitializeDeviceType(DeviceType.CUDA);
}

var defaultType = ScalarType.Float16;
torch.manual_seed(1);
torch.set_default_dtype(defaultType);
var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
var client = new Phi3CausalLMChatClient(pipeline);

var task = """
Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
""";
var chatMessage = new ChatMessage(ChatRole.User, task);

await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
{
Console.Write(response.Text);
}
}
}
4 changes: 3 additions & 1 deletion docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// See https://aka.ms/new-console-template for more information
using Microsoft.ML.GenAI.Samples.Llama;
using Microsoft.ML.GenAI.Samples.MEAI;

await LlamaSample.RunLlama(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-3B-Instruct");
//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
3 changes: 2 additions & 1 deletion eng/Versions.props
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
<SystemRuntimeCompilerServicesUnsafeVersion>6.0.0</SystemRuntimeCompilerServicesUnsafeVersion>
<SystemSecurityPrincipalWindows>5.0.0</SystemSecurityPrincipalWindows>
<SystemTextEncodingsWebVersion>8.0.0</SystemTextEncodingsWebVersion>
<SystemTextJsonVersion>8.0.4</SystemTextJsonVersion>
<SystemTextJsonVersion>8.0.5</SystemTextJsonVersion>
<SystemThreadingChannelsVersion>8.0.0</SystemThreadingChannelsVersion>
<!-- Other product dependencies -->
<ApacheArrowVersion>14.0.2</ApacheArrowVersion>
Expand All @@ -47,6 +47,7 @@
<MicrosoftDotNetInteractiveVersion>1.0.0-beta.24375.2</MicrosoftDotNetInteractiveVersion>
<MicrosoftMLOnnxRuntimeVersion>1.18.1</MicrosoftMLOnnxRuntimeVersion>
<MlNetMklDepsVersion>0.0.0.12</MlNetMklDepsVersion>
<MicrosoftExtensionsAIVersion>9.0.0-preview.9.24507.7</MicrosoftExtensionsAIVersion>
<!--
@("inteltbb.devel", "win", "2021.7.1.15305")
-->
Expand Down
89 changes: 89 additions & 0 deletions src/Microsoft.ML.GenAI.Core/CausalLMPipelineChatClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.Tokenizers;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core;

public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : IChatClient
where TTokenizer : Tokenizer
where TCausalLMModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
private readonly ICausalLMPipeline<TTokenizer, TCausalLMModel> _pipeline;
private readonly IMEAIChatTemplateBuilder _chatTemplateBuilder;

public CausalLMPipelineChatClient(
ICausalLMPipeline<TTokenizer, TCausalLMModel> pipeline,
IMEAIChatTemplateBuilder chatTemplateBuilder,
ChatClientMetadata? metadata = null)
{
var classNameWithType = $"{nameof(CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>)}<{typeof(TTokenizer).Name}, {typeof(TCausalLMModel).Name}>";
Metadata ??= new ChatClientMetadata(providerName: classNameWithType, modelId: typeof(TCausalLMModel).Name);
_chatTemplateBuilder = chatTemplateBuilder;
_pipeline = pipeline;
}

public ChatClientMetadata Metadata { get; }

public virtual Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
var stopSequences = options?.StopSequences ?? Array.Empty<string>();

var output = _pipeline.Generate(
prompt,
maxLen: options?.MaxOutputTokens ?? 1024,
temperature: options?.Temperature ?? 0.7f,
stopSequences: stopSequences.ToArray()) ?? throw new InvalidOperationException("Failed to generate a reply.");

var chatMessage = new ChatMessage(ChatRole.Assistant, output);
return Task.FromResult(new ChatCompletion([chatMessage])
{
CreatedAt = DateTime.UtcNow,
FinishReason = ChatFinishReason.Stop,
});
}

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public virtual async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
var stopSequences = options?.StopSequences ?? Array.Empty<string>();

foreach (var output in _pipeline.GenerateStreaming(
prompt,
maxLen: options?.MaxOutputTokens ?? 1024,
temperature: options?.Temperature ?? 0.7f,
stopSequences: stopSequences.ToArray()))
{
yield return new StreamingChatCompletionUpdate
{
Role = ChatRole.Assistant,
Text = output,
CreatedAt = DateTime.UtcNow,
};
}
}

public virtual void Dispose()
{
}

public virtual TService? GetService<TService>(object? key = null) where TService : class
{
return null;
}
}
1 change: 1 addition & 0 deletions src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

<ItemGroup>
<PackageReference Include="AutoGen.Core" Version="$(AutoGenVersion)" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsAIVersion)" />
<PackageReference Include="Microsoft.SemanticKernel.Abstractions" Version="$(SemanticKernelVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
Expand Down
6 changes: 6 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Text;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.ML.GenAI.Core;
Expand All @@ -22,6 +23,11 @@ public interface IAutoGenChatTemplateBuilder
string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null);
}

public interface IMEAIChatTemplateBuilder
{
string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null);
}

public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
{
}
57 changes: 57 additions & 0 deletions src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;

namespace Microsoft.ML.GenAI.LLaMA;

public class Llama3CausalLMChatClient : CausalLMPipelineChatClient<Tokenizer, LlamaForCausalLM>
{
private readonly string _eotToken = "<|eot_id|>";

public Llama3CausalLMChatClient(
ICausalLMPipeline<Tokenizer, LlamaForCausalLM> pipeline,
IMEAIChatTemplateBuilder? chatTemplateBuilder = null,
ChatClientMetadata? metadata = null)
: base(
pipeline,
chatTemplateBuilder ?? Llama3_1ChatTemplateBuilder.Instance,
metadata ?? new ChatClientMetadata(modelId: nameof(Llama3CausalLMChatClient)))
{
}

public override Task<ChatCompletion> CompleteAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();

if (options.StopSequences != null)
{
options.StopSequences.Add(_eotToken);
}
else
{
options.StopSequences = new List<string> { _eotToken };
}

return base.CompleteAsync(chatMessages, options, cancellationToken);
}

public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
CancellationToken cancellationToken = default)
{
options ??= new ChatOptions();
options.StopSequences ??= [];
options.StopSequences.Add(_eotToken);

return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
}
}
38 changes: 37 additions & 1 deletion src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

using System.Text;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using TextContent = Microsoft.SemanticKernel.TextContent;

namespace Microsoft.ML.GenAI.LLaMA;
#pragma warning disable MSML_GeneralName // This name should be PascalCased
public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBuilder
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
private const char Newline = '\n';
Expand Down Expand Up @@ -86,5 +88,39 @@ public string BuildPrompt(ChatHistory chatHistory)
return sb.ToString();
}

public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
{
var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
if (messages.Any(m => m.Text is null))
{
throw new InvalidOperationException("Please provide a message with content.");
}

if (messages.Any(m => availableRoles.Any(availableRole => availableRole == m.Role) == false))
{
throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
}

var sb = new StringBuilder();
sb.Append("<|begin_of_text|>");
foreach (var message in messages)
{
var role = message.Role.Value;
var content = message.Text!;
sb.Append(message switch
{
_ when message.Role == ChatRole.System => $"<|start_header_id|>system<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
_ when message.Role == ChatRole.User => $"<|start_header_id|>user<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
_ when message.Role == ChatRole.Assistant => $"<|start_header_id|>assistant<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
_ => throw new InvalidOperationException("Invalid role.")
});
}

sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
var input = sb.ToString();

return input;
}

public static Llama3_1ChatTemplateBuilder Instance { get; } = new Llama3_1ChatTemplateBuilder();
}
Loading

0 comments on commit 6ed398e

Please sign in to comment.