Skip to content

Add GenAI core package #7177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,11 @@ Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "Microsoft.ML.FSharp.Tests",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Data.Analysis.PerformanceTests", "test\Microsoft.Data.Analysis.PerformanceTests\Microsoft.Data.Analysis.PerformanceTests.csproj", "{FB8A8823-CC6C-4C2F-8539-05FBFB7C91CD}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.TorchSharp.Tests", "test\Microsoft.ML.TorchSharp.Tests\Microsoft.ML.TorchSharp.Tests.csproj", "{AB8D68F1-6C3E-41FD-B0EC-A093E009341D}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TorchSharp.Tests", "test\Microsoft.ML.TorchSharp.Tests\Microsoft.ML.TorchSharp.Tests.csproj", "{AB8D68F1-6C3E-41FD-B0EC-A093E009341D}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.TensorFlow.Tests", "test\Microsoft.ML.TensorFlow.Tests\Microsoft.ML.TensorFlow.Tests.csproj", "{763FF013-8309-4680-A769-B54E7BB99612}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow.Tests", "test\Microsoft.ML.TensorFlow.Tests\Microsoft.ML.TensorFlow.Tests.csproj", "{763FF013-8309-4680-A769-B54E7BB99612}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.GenAI.Core", "src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj", "{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand Down Expand Up @@ -512,6 +514,14 @@ Global
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Release|Any CPU.Build.0 = Release|Any CPU
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Release|x64.ActiveCfg = Release|Any CPU
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}.Release|x64.Build.0 = Release|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Debug|x64.ActiveCfg = Debug|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Debug|x64.Build.0 = Debug|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Release|Any CPU.Build.0 = Release|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Release|x64.ActiveCfg = Release|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Release|x64.Build.0 = Release|Any CPU
{9222FC9D-599A-49A5-B685-08CC9A5C81D7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{9222FC9D-599A-49A5-B685-08CC9A5C81D7}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9222FC9D-599A-49A5-B685-08CC9A5C81D7}.Debug|x64.ActiveCfg = Debug|Any CPU
Expand Down Expand Up @@ -820,14 +830,14 @@ Global
{763FF013-8309-4680-A769-B54E7BB99612}.Release|Any CPU.Build.0 = Release|Any CPU
{763FF013-8309-4680-A769-B54E7BB99612}.Release|x64.ActiveCfg = Release|Any CPU
{763FF013-8309-4680-A769-B54E7BB99612}.Release|x64.Build.0 = Release|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Debug|x64.ActiveCfg = Debug|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Debug|x64.Build.0 = Debug|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Release|Any CPU.Build.0 = Release|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Release|x64.ActiveCfg = Release|Any CPU
{39E89702-1A46-4D5B-BA50-530D11309B5E}.Release|x64.Build.0 = Release|Any CPU
{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Debug|Any CPU.Build.0 = Debug|Any CPU
{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Debug|x64.ActiveCfg = Debug|Any CPU
{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Debug|x64.Build.0 = Debug|Any CPU
{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Release|Any CPU.ActiveCfg = Release|Any CPU
{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Release|Any CPU.Build.0 = Release|Any CPU
{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Release|x64.ActiveCfg = Release|Any CPU
{DB2CA055-8ABD-4E3E-8089-5B64C3415E85}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -874,6 +884,7 @@ Global
{11A5210E-2EA7-42F1-80DB-827762E9C781} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31} = {DA452A53-2E94-4433-B08C-041EDEC729E6}
{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{39E89702-1A46-4D5B-BA50-530D11309B5E} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{9222FC9D-599A-49A5-B685-08CC9A5C81D7} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{6C29AA9B-054B-4762-BEA5-D305B932AA80} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{4805129D-78C8-46D4-9519-0AD9B0574D6D} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
Expand Down Expand Up @@ -913,7 +924,7 @@ Global
{FB8A8823-CC6C-4C2F-8539-05FBFB7C91CD} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{AB8D68F1-6C3E-41FD-B0EC-A093E009341D} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{763FF013-8309-4680-A769-B54E7BB99612} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{39E89702-1A46-4D5B-BA50-530D11309B5E} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{DB2CA055-8ABD-4E3E-8089-5B64C3415E85} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
50 changes: 50 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Extension/CausalLMPipelineExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static TorchSharp.torch;
using TorchSharp;

namespace Microsoft.ML.GenAI.Core.Extension;

public static class CausalLMPipelineExtension
{
    /// <summary>
    /// Generate a completion for <paramref name="prompt"/> using the pipeline's causal language model.
    /// The prompt is tokenized, run through <see cref="CausalLMPipeline.Generate"/>, and the resulting
    /// token ids are decoded back to text.
    /// </summary>
    /// <param name="pipeline">The pipeline providing the tokenizer and the underlying model.</param>
    /// <param name="prompt">The input text to complete.</param>
    /// <param name="maxLen">Maximum number of tokens to generate.</param>
    /// <param name="temperature">Sampling temperature passed to the model.</param>
    /// <param name="topP">Nucleus-sampling threshold passed to the model.</param>
    /// <param name="stopSequences">Optional strings whose token sequences terminate generation when produced.</param>
    /// <param name="eosId">End-of-sequence token id; always included as a stop sequence.</param>
    /// <param name="device">Torch device on which the input tensor is created (e.g. "cpu").</param>
    /// <param name="bos">NOTE(review): currently unused by this implementation — confirm intent.</param>
    /// <param name="eos">NOTE(review): currently unused by this implementation — confirm intent.</param>
    /// <param name="echo">When true, the prompt tokens are echoed in the output (forwarded to the model).</param>
    /// <returns>The decoded generated text, or null if the tokenizer cannot decode the output.</returns>
    public static string? Generate(
        this CausalLMPipeline pipeline,
        string prompt,
        int maxLen = 128,
        float temperature = 0.7f,
        float topP = 0.9f,
        string[]? stopSequences = null,
        int eosId = 0,
        string device = "cpu",
        bool bos = true,
        bool eos = false,
        bool echo = false)
    {
        // Dispose all intermediate tensors when generation completes.
        using var scope = NewDisposeScope();

        var promptIds = pipeline.Tokenizer.EncodeToIds(prompt);
        var input = torch.tensor(promptIds.ToArray(), dtype: ScalarType.Int64, device: device).unsqueeze(0);
        var mask = torch.ones_like(input);

        // Stop token ids: [[eosId], [stopSequence1], [stopSequence2], ...]
        // Generation halts as soon as the model emits any of these sequences.
        var stopTokenSequences = new List<int[]> { new[] { eosId } };
        if (stopSequences is not null)
        {
            foreach (var sequence in stopSequences)
            {
                stopTokenSequences.Add(pipeline.Tokenizer.EncodeToIds(sequence).ToArray());
            }
        }

        var (outputTokens, _) = pipeline.Generate(input, mask, temperature: temperature, maxLen: maxLen, topP: topP, stopTokenSequence: stopTokenSequences.ToArray(), echo: echo);

        // First (and only) batch row; widen to int32 so the ids can be read on the CPU side.
        var generatedIds = outputTokens[0].to_type(ScalarType.Int32).data<int>().ToArray();

        return pipeline.Tokenizer.Decode(generatedIds);
    }
}
241 changes: 241 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using TorchSharp;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core.Extension;

internal static class ModuleExtension
{
    /// <summary>
    /// Compute the total size, in bytes, of every tensor in the module's state_dict.
    /// </summary>
    /// <param name="model">The module to measure.</param>
    /// <returns>Sum of <c>numel * element_size</c> over all state_dict entries.</returns>
    public static long GetSizeInBytes(this nn.Module model)
    {
        var stateDict = model.state_dict();
        long size = 0;
        foreach (var (_, value) in stateDict)
        {
            size += value.numel() * value.element_size();
        }

        return size;
    }

    /// <summary>
    /// Recursively compute the size, in bytes, of each descendant module that implements
    /// <see cref="IDynamicLoadModule"/>. Keys are dotted module paths (e.g. "layers.0").
    /// </summary>
    /// <param name="model">The root module to inspect.</param>
    /// <returns>Map from dotted layer path to size in bytes; empty if the module has no children.</returns>
    public static Dictionary<string, long> GetSizeForEachDynamicLayerInBytes(this nn.Module model)
    {
        var children = model.named_children();
        if (!children.Any())
        {
            return new();
        }

        var dict = new Dictionary<string, long>();
        foreach (var (key, value) in children)
        {
            if (value is IDynamicLoadModule)
            {
                dict[key] = value.GetSizeInBytes();
            }
            else
            {
                // Not dynamically loadable itself; descend and record its dynamic descendants
                // under a dotted path rooted at this child's name.
                var subDict = value.GetSizeForEachDynamicLayerInBytes();
                foreach (var (subKey, subValue) in subDict)
                {
                    dict[key + "." + subKey] = subValue;
                }
            }
        }

        return dict;
    }

    /// <summary>
    /// Quantize, in place, every descendant module that implements <see cref="IQuantizeModule"/>.
    /// Children that do not implement it are recursed into.
    /// </summary>
    /// <param name="model">The root module to quantize.</param>
    public static void ToQuantizedModule<T>(
        this T model)
        where T : nn.Module
    {
        foreach (var (_, value) in model.named_children())
        {
            if (value is IQuantizeModule quantizeModule)
            {
                quantizeModule.Quantize();
            }
            else
            {
                value.ToQuantizedModule();
            }
        }
    }

    /// <summary>
    /// Distribute the model's layers across devices according to <paramref name="deviceMap"/>.
    /// Layers implementing <see cref="IDynamicLoadModule"/> that are mapped to a device other than
    /// <paramref name="targetDevice"/> get load/unload callbacks so they can be swapped onto the
    /// target device on demand.
    /// </summary>
    /// <param name="model">The model to place.</param>
    /// <param name="deviceMap">Map from dotted layer path to device id; when empty the whole model moves to <paramref name="targetDevice"/>.</param>
    /// <param name="targetDevice">The device on which computation runs.</param>
    /// <returns>The same model instance, for chaining.</returns>
    public static T ToDynamicLoadingModel<T>(
        this T model,
        Dictionary<string, string> deviceMap,
        string targetDevice)
        where T : nn.Module
    {
        if (deviceMap.Count == 0)
        {
            model.to(new Device(targetDevice));

            return model;
        }

        // for each module in the model, update device if it is IDynamicLoadModule
        foreach (var (key, value) in model.named_children())
        {
            if (value is IDynamicLoadModule dynamicModule)
            {
                var device = deviceMap[key];
                if (device != targetDevice)
                {
                    // Load pulls the layer to the compute device; unload parks it back
                    // on its assigned (cheaper) device.
                    dynamicModule.LoadToDeviceFunc = (nn.Module module) =>
                    {
                        module.to(new Device(targetDevice));
                    };
                    dynamicModule.UnloadFromDeviceFunc = (nn.Module module) =>
                    {
                        module.to(new Device(device));
                    };
                }

                value.to(new Device(device));
            }
            else
            {
                // Strip this child's prefix from the map keys and recurse into it.
                var childrenDeviceMap = deviceMap.Where(x => x.Key.StartsWith($"{key}.")).ToDictionary(x => x.Key.Substring($"{key}.".Length), x => x.Value);
                value.ToDynamicLoadingModel(childrenDeviceMap, targetDevice);
            }
        }

        return model;
    }

    /// <summary>
    /// Infer a device placement for each dynamic layer in the model.
    /// Layers are assigned greedily, largest first, to the devices in <paramref name="devices"/> order,
    /// until a device's budget (minus a reserve of twice the largest layer) is exhausted.
    /// </summary>
    /// <param name="model"></param>
    /// <param name="devices">a list of device ids (e.g. ["cuda:0", "cpu", "disk"])</param>
    /// <param name="deviceSizeMapInByte">a map where the key is the device id (e.g. "cuda:0") and the value is the memory size in bytes of the device</param>
    /// <returns>A map from dotted layer path to the device id it was assigned to.</returns>
    public static Dictionary<string, string> InferDeviceMapForEachLayer(
        this nn.Module model,
        string[] devices,
        Dictionary<string, long> deviceSizeMapInByte)
    {
        var layerSizeMap = model.GetSizeForEachDynamicLayerInBytes();
        var deviceMap = new Dictionary<string, string>();
        // No dynamic layers: nothing to place (also avoids Max() throwing on an empty sequence).
        if (layerSizeMap.Count == 0)
        {
            return deviceMap;
        }

        // Keep headroom on each device for two copies of the largest layer (e.g. for swapping).
        var sizeToRemainOnEachDevice = 2 * layerSizeMap.Max(x => x.Value);
        foreach (var device in devices)
        {
            long size = deviceSizeMapInByte[device];
            var remainingLayerSizeMap = layerSizeMap.Where(x => !deviceMap.ContainsKey(x.Key)).ToDictionary(x => x.Key, x => x.Value);
            // larger layer fit first
            foreach (var (key, value) in remainingLayerSizeMap.OrderByDescending(x => x.Value))
            {
                if (size >= value)
                {
                    deviceMap[key] = device;
                    size -= value;
                }

                if (size < sizeToRemainOnEachDevice)
                {
                    break;
                }
            }
        }

        return deviceMap;
    }

    /// <summary>
    /// Render a debug fingerprint (via <see cref="TensorExtension.Peek(Tensor, string, int)"/>)
    /// of every entry in the model's state_dict, one numbered line per tensor, sorted by key.
    /// </summary>
    /// <param name="model">The module whose state_dict is previewed.</param>
    /// <returns>A multi-line summary string.</returns>
    public static string Peek(this nn.Module model)
    {
        var sb = new StringBuilder();
        var stateDict = model.state_dict();
        // preview state_dict
        int i = 0;
        foreach (var (key, value) in stateDict.OrderBy(x => x.Key, StringComparer.OrdinalIgnoreCase))
        {
            var str = value.Peek(key);
            sb.AppendLine($"{i}: {str}");
            i++;
        }

        var res = sb.ToString();

        return res;
    }

    /// <summary>
    /// Render the shape of every entry in the model's state_dict, one numbered line per tensor,
    /// sorted by key.
    /// </summary>
    /// <param name="model">The module whose state_dict is previewed.</param>
    /// <returns>A multi-line summary string, e.g. "0: weight shape: [4096, 4096]".</returns>
    public static string PeekShape(this nn.Module model)
    {
        var sb = new StringBuilder();
        var stateDict = model.state_dict();
        // preview state_dict
        int i = 0;
        foreach (var (key, value) in stateDict.OrderBy(x => x.Key, StringComparer.OrdinalIgnoreCase))
        {
            // shape str: [x, y, z]
            var shapeStr = string.Join(", ", value.shape);
            sb.AppendLine($"{i}: {key} shape: [{shapeStr}]");
            i++;
        }

        var res = sb.ToString();

        return res;
    }

    /// <summary>
    /// Load tensor data from the file at <paramref name="location"/> into the existing entries of
    /// <paramref name="dict"/>. The file format is: a 7-bit-encoded count, then for each tensor its
    /// key string followed by its serialized data (read by <see cref="TensorExtensionMethods.Load"/>).
    /// </summary>
    /// <param name="dict">State dictionary whose tensors are overwritten in place; every key in the file must already exist here.</param>
    /// <param name="location">Path of the serialized state file.</param>
    public static void LoadStateDict(this Dictionary<string, Tensor> dict, string location)
    {
        using FileStream stream = File.OpenRead(location);
        using BinaryReader reader = new BinaryReader(stream);
        var num = reader.Decode();
        for (int i = 0; i < num; i++)
        {
            var key = reader.ReadString();
            Tensor tensor = dict[key];

            var originalType = tensor.dtype;
            // BFloat16 cannot be loaded directly; round-trip through Float32.
            // NOTE(review): the tensor's original device is not restored after loading — confirm
            // callers expect the tensor to stay wherever Load leaves it.
            if (tensor.dtype == ScalarType.BFloat16)
            {
                tensor = tensor.to_type(ScalarType.Float32);
            }

            TensorExtensionMethods.Load(ref tensor!, reader, skip: false);

            tensor = tensor!.to_type(originalType);
            dict[key] = tensor;
        }
    }

    /// <summary>
    /// Read a 7-bit-encoded integer (little-endian variable-length, high bit of each byte marks
    /// continuation — the format written by <see cref="BinaryWriter.Write7BitEncodedInt64"/>).
    /// </summary>
    /// <param name="reader">Source reader, positioned at the first byte of the encoded value.</param>
    /// <returns>The decoded value.</returns>
    public static long Decode(this BinaryReader reader)
    {
        long num = 0L;
        int num2 = 0;
        while (true)
        {
            long num3 = reader.ReadByte();
            // Low 7 bits carry payload; shift into place by group index.
            num += (num3 & 0x7F) << num2 * 7;
            if ((num3 & 0x80) == 0L)
            {
                break;
            }

            num2++;
        }

        return num;
    }
}
33 changes: 33 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Extension/TensorExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using TorchSharp;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core.Extension;

internal static class TensorExtension
{
    /// <summary>
    /// Produce a one-line debug fingerprint of <paramref name="tensor"/>: a single scalar
    /// summarizing its contents plus its dtype and shape. Intended for eyeballing whether two
    /// tensors (e.g. across ports of a model) hold the same values.
    /// </summary>
    /// <param name="tensor">The tensor to summarize; it is copied to CPU (and fp16 widened to fp32) before reading.</param>
    /// <param name="id">Label prefixed to the output, typically the tensor's state_dict key.</param>
    /// <param name="n">Unused; retained for backward compatibility with existing callers.</param>
    /// <returns>A string of the form "{id}: sum: {value} dType: {dtype} shape: [{dims}]".</returns>
    public static string Peek(this Tensor tensor, string id, int n = 10)
    {
        var dType = tensor.dtype;
        // if type is fp16, convert to fp32 so element-wise reads below are supported
        if (tensor.dtype == ScalarType.Float16)
        {
            tensor = tensor.to_type(ScalarType.Float32);
        }
        tensor = tensor.cpu();
        var shapeString = string.Join(',', tensor.shape);
        var tensor1D = tensor.reshape(-1);
        // Weight each element by sqrt(index) so the fingerprint is sensitive to element order,
        // not just the raw total; then normalize by the plain sum.
        // NOTE(review): despite the "sum:" label in the output, this is a weighted ratio, not a sum.
        var tensorIndex = torch.arange(tensor1D.shape[0], dtype: ScalarType.Float32).to(tensor1D.device).sqrt();
        var avg = (tensor1D * tensorIndex).sum();
        avg = avg / tensor1D.sum();
        // keep four decimal places
        avg = avg.round(4);
        var str = $"{id}: sum: {avg.ToSingle()} dType: {dType} shape: [{shapeString}]";

        return str;
    }
}
Loading