Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

LoRA loading and LoRA application #20

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions OnnxStack.Console/Examples/LoRADebug.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using Microsoft.ML.OnnxRuntime;
using OnnxStack.Core.Services;

namespace OnnxStack.Console.Runner
{
public sealed class LoRADebug : IExampleRunner
{
private readonly string _outputDirectory;
private readonly IOnnxModelService _modelService;
private readonly IOnnxModelAdaptaterService _modelAdaptaterService;

public LoRADebug(IOnnxModelService modelService)
{
_modelService = modelService;
_outputDirectory = Path.Combine(Directory.GetCurrentDirectory(), "Examples", nameof(StableDebug));
}

public string Name => "LoRA Debug";

public string Description => "LoRA Debugger";

public async Task RunAsync()
{
string modelPath = "D:\\Repositories\\stable-diffusion-v1-5\\unet\\model.onnx";
string loraModelPath = "D:\\Repositories\\LoRAFiles\\model.onnx";

using (var modelession = new InferenceSession(modelPath))
using (var loraModelSession = new InferenceSession(loraModelPath))
{
try
{
_modelAdaptaterService.ApplyLowRankAdaptation(modelession, loraModelSession);
}
catch (Exception ex)
{

}
}
}

}
}
14 changes: 14 additions & 0 deletions OnnxStack.Core/Model/OnnxModelAdapter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using OnnxStack.Core.Config;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace OnnxStack.Core.Model
{
public class OnnxModelAdapter : IOnnxModel
{
public string Name { get; set; }
}
}
1 change: 1 addition & 0 deletions OnnxStack.Core/Registration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static void AddOnnxStack(this IServiceCollection serviceCollection)
{
serviceCollection.AddSingleton(ConfigManager.LoadConfiguration());
serviceCollection.AddSingleton<IOnnxModelService, OnnxModelService>();
serviceCollection.AddSingleton<IOnnxModelAdaptaterService, OnnxModelAdaptaterService>();
}


Expand Down
9 changes: 9 additions & 0 deletions OnnxStack.Core/Services/IOnnxModelAdaptaterService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Microsoft.ML.OnnxRuntime;

namespace OnnxStack.Core.Services
{
public interface IOnnxModelAdaptaterService
{
void ApplyLowRankAdaptation(InferenceSession primarySession, InferenceSession loraSession);
}
}
50 changes: 50 additions & 0 deletions OnnxStack.Core/Services/OnnxModelAdaptaterService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Linq;

namespace OnnxStack.Core.Services
{
public class OnnxModelAdaptaterService : IOnnxModelAdaptaterService
{
public void ApplyLowRankAdaptation(InferenceSession primarySession, InferenceSession loraSession)
{
// For simplicity, let's assume we will replace the weights of the first dense layer
string layerName = "layer_name";

// Get the current weights from the primary model
var primaryInputName = primarySession.InputMetadata.Keys.First();
var primaryInputTensor = primarySession.InputMetadata[primaryInputName];
var primaryWeights = new float[primaryInputTensor.Dimensions.Product()];

// Get the weights from the LoRA model
var lraInputName = loraSession.InputMetadata.Keys.First();
var lraInputTensor = loraSession.InputMetadata[lraInputName];
var lraWeights = new float[lraInputTensor.Dimensions.Product()];

// Apply LoRA (replace weights) this is where we will do the mutiplication of the weights
// but for testing sake just brute for replacing
Array.Copy(lraWeights, primaryWeights, Math.Min(primaryWeights.Length, lraWeights.Length));

// Update the primary model tensor with the modified weights
var tensor = new DenseTensor<float>(primaryWeights, primaryInputTensor.Dimensions.ToArray());
var inputs = new NamedOnnxValue[] { NamedOnnxValue.CreateFromTensor(primaryInputName, tensor) };

// Will it run?
primarySession.Run(inputs);
}
}

public static class Ext
{
public static int Product(this int[] array)
{
int result = 1;
foreach (int element in array)
{
result *= element;
}
return result;
}
}
}