Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit c1aa906

Browse files
authored
Merge pull request #31 from saddam213/OnnxNativeTypes
Initial Float16 and BFloat16 onnx type support
2 parents dd9c2c9 + a44a302 commit c1aa906

File tree

9 files changed

+230
-63
lines changed

9 files changed

+230
-63
lines changed

OnnxStack.Core/Extensions.cs

Lines changed: 145 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
using Microsoft.ML.OnnxRuntime.Tensors;
33
using OnnxStack.Core.Config;
44
using System;
5+
using System.Buffers;
56
using System.Collections.Concurrent;
67
using System.Collections.Generic;
78
using System.Linq;
89
using System.Numerics;
10+
using System.Runtime.InteropServices;
911

1012
namespace OnnxStack.Core
1113
{
@@ -205,26 +207,166 @@ public static T GetBufferLength<T>(this ReadOnlySpan<T> array) where T : INumber
205207
}
206208

207209

210+
/// <summary>
211+
/// Converts to long.
212+
/// </summary>
213+
/// <param name="array">The array.</param>
214+
/// <returns></returns>
208215
public static long[] ToLong(this ReadOnlySpan<int> array)
209216
{
210217
return Array.ConvertAll(array.ToArray(), Convert.ToInt64);
211218
}
212-
219+
220+
221+
/// <summary>
222+
/// Converts the string representation of a number to an integer.
223+
/// </summary>
224+
/// <param name="array">The array.</param>
225+
/// <returns></returns>
213226
public static int[] ToInt(this long[] array)
214227
{
215228
return Array.ConvertAll(array, Convert.ToInt32);
216229
}
217230

231+
232+
/// <summary>
233+
/// Converts to long.
234+
/// </summary>
235+
/// <param name="array">The array.</param>
236+
/// <returns></returns>
218237
public static long[] ToLong(this int[] array)
219238
{
220239
return Array.ConvertAll(array, Convert.ToInt64);
221240
}
222241

223242

224-
public static OrtValue ToOrtValue<T>(this DenseTensor<T> tensor) where T : unmanaged
243+
/// <summary>
244+
/// Creates an OrtValue from the DenseTensor and NodeMetadata provided.
245+
/// </summary>
246+
/// <param name="tensor">The tensor.</param>
247+
/// <param name="nodeMetadata">The node metadata.</param>
248+
/// <returns></returns>
249+
public static OrtValue ToOrtValue(this DenseTensor<float> tensor, NodeMetadata nodeMetadata)
225250
{
226-
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, tensor.Dimensions.ToLong());
251+
var dimensions = tensor.Dimensions.ToLong();
252+
return nodeMetadata.ElementDataType switch
253+
{
254+
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToFloat16(), dimensions),
255+
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer.ToBFloat16(), dimensions),
256+
_ => OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, tensor.Buffer, dimensions)
257+
};
227258
}
228259

260+
261+
/// <summary>
262+
/// Creates and allocates an output tensor buffer.
263+
/// </summary>
264+
/// <param name="nodeMetadata">The node metadata.</param>
265+
/// <param name="dimensions">The dimensions.</param>
266+
/// <returns></returns>
267+
public static OrtValue CreateOutputBuffer(this NodeMetadata nodeMetadata, ReadOnlySpan<int> dimensions)
268+
{
269+
return OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, nodeMetadata.ElementDataType, dimensions.ToLong());
270+
}
271+
272+
273+
/// <summary>
274+
/// Converts to DenseTensor&lt;float&gt;.
275+
/// </summary>
276+
/// <param name="ortValue">The ort value.</param>
277+
/// <returns></returns>
278+
public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue)
279+
{
280+
var typeInfo = ortValue.GetTensorTypeAndShape();
281+
var dimensions = typeInfo.Shape.ToInt();
282+
return typeInfo.ElementDataType switch
283+
{
284+
TensorElementType.Float16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<Float16>().ToFloat(), dimensions),
285+
TensorElementType.BFloat16 => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat(), dimensions),
286+
_ => new DenseTensor<float>(ortValue.GetTensorDataAsSpan<float>().ToArray(), dimensions)
287+
};
288+
}
289+
290+
291+
/// <summary>
292+
/// Converts to array.
293+
/// </summary>
294+
/// <param name="ortValue">The ort value.</param>
295+
/// <returns></returns>
296+
public static float[] ToArray(this OrtValue ortValue)
297+
{
298+
var typeInfo = ortValue.GetTensorTypeAndShape();
299+
var dimensions = typeInfo.Shape.ToInt();
300+
return typeInfo.ElementDataType switch
301+
{
302+
TensorElementType.Float16 => ortValue.GetTensorDataAsSpan<Float16>().ToFloat().ToArray(),
303+
TensorElementType.BFloat16 => ortValue.GetTensorDataAsSpan<BFloat16>().ToFloat().ToArray(),
304+
_ => ortValue.GetTensorDataAsSpan<float>().ToArray()
305+
};
306+
}
307+
308+
309+
/// <summary>
310+
/// Converts to float16.
311+
/// </summary>
312+
/// <param name="inputMemory">The input memory.</param>
313+
/// <returns></returns>
314+
internal static Memory<Float16> ToFloat16(this Memory<float> inputMemory)
315+
{
316+
var elementCount = inputMemory.Length;
317+
var floatArray = new Float16[inputMemory.Length];
318+
for (int i = 0; i < elementCount; i++)
319+
floatArray[i] = (Float16)inputMemory.Span[i];
320+
321+
return floatArray.AsMemory();
322+
}
323+
324+
325+
/// <summary>
326+
/// Converts to BFloat16.
327+
/// </summary>
328+
/// <param name="inputMemory">The input memory.</param>
329+
/// <returns></returns>
330+
internal static Memory<BFloat16> ToBFloat16(this Memory<float> inputMemory)
331+
{
332+
var elementCount = inputMemory.Length;
333+
var floatArray = new BFloat16[inputMemory.Length];
334+
for (int i = 0; i < elementCount; i++)
335+
floatArray[i] = (BFloat16)inputMemory.Span[i];
336+
337+
return floatArray.AsMemory();
338+
}
339+
340+
341+
/// <summary>
342+
/// Converts to float.
343+
/// </summary>
344+
/// <param name="inputMemory">The input memory.</param>
345+
/// <returns></returns>
346+
internal static Memory<float> ToFloat(this ReadOnlySpan<Float16> inputMemory)
347+
{
348+
var elementCount = inputMemory.Length;
349+
var floatArray = new float[elementCount];
350+
for (int i = 0; i < elementCount; i++)
351+
floatArray[i] = (float)inputMemory[i];
352+
353+
return floatArray.AsMemory();
354+
}
355+
356+
357+
/// <summary>
358+
/// Converts to float.
359+
/// </summary>
360+
/// <param name="inputMemory">The input memory.</param>
361+
/// <returns></returns>
362+
internal static Memory<float> ToFloat(this ReadOnlySpan<BFloat16> inputMemory)
363+
{
364+
var elementCount = inputMemory.Length;
365+
var floatArray = new float[elementCount];
366+
for (int i = 0; i < elementCount; i++)
367+
floatArray[i] = (float)inputMemory[i];
368+
369+
return floatArray.AsMemory();
370+
}
229371
}
230372
}

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,20 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(IModelOption
212212

213213
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
214214
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeDecoder);
215+
var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeDecoder);
216+
var outputTensorMetaData = outputMetaData[outputNames[0]];
215217

216218
var outputDim = new[] { 1, 3, options.Height, options.Width };
217-
var outputBuffer = new DenseTensor<float>(outputDim);
218-
using (var inputTensorValue = latents.ToOrtValue())
219-
using (var outputTensorValue = outputBuffer.ToOrtValue())
219+
using (var inputTensorValue = latents.ToOrtValue(outputTensorMetaData))
220+
using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDim))
220221
{
221222
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
222223
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
223224
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputs, outputs);
224225
using (var imageResult = results.First())
225226
{
226227
_logger?.LogEnd("End", timestamp);
227-
return outputBuffer;
228+
return imageResult.ToDenseTensor();
228229
}
229230
}
230231
}
@@ -237,13 +238,16 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(IModelOption
237238
/// <param name="timestepInputName">Name of the timestep input.</param>
238239
/// <param name="timestep">The timestep.</param>
239240
/// <returns></returns>
240-
protected static OrtValue CreateTimestepNamedOrtValue(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata, string timestepInputName, int timestep)
241+
protected static OrtValue CreateTimestepNamedOrtValue(NodeMetadata timestepMetaData, int timestep)
241242
{
242-
// Some models support Long or Float, could be more but fornow just support these 2
243-
var timestepMetaData = nodeMetadata[timestepInputName];
244-
return timestepMetaData.ElementDataType == TensorElementType.Int64
245-
? OrtValue.CreateTensorValueFromMemory(new long[] { timestep }, new long[] { 1 })
246-
: OrtValue.CreateTensorValueFromMemory(new float[] { timestep }, new long[] { 1 });
243+
var dimension = new long[] { 1 };
244+
return timestepMetaData.ElementDataType switch
245+
{
246+
TensorElementType.Int64 => OrtValue.CreateTensorValueFromMemory(new long[] { timestep }, dimension),
247+
TensorElementType.Float16 => OrtValue.CreateTensorValueFromMemory(new Float16[] { (Float16)timestep }, dimension),
248+
TensorElementType.BFloat16 => OrtValue.CreateTensorValueFromMemory(new BFloat16[] { (BFloat16)timestep }, dimension),
249+
_ => OrtValue.CreateTensorValueFromMemory(new float[] { timestep }, dimension) // TODO: Default to Float32 for now
250+
};
247251
}
248252

249253

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.Extensions.Logging;
22
using Microsoft.ML.OnnxRuntime;
33
using Microsoft.ML.OnnxRuntime.Tensors;
4+
using OnnxStack.Core;
45
using OnnxStack.Core.Config;
56
using OnnxStack.Core.Services;
67
using OnnxStack.StableDiffusion.Common;
@@ -12,7 +13,6 @@
1213
using System.Collections.Generic;
1314
using System.Linq;
1415
using System.Threading.Tasks;
15-
using OnnxStack.Core;
1616

1717
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
1818
{
@@ -61,19 +61,22 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
6161
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
6262
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder);
6363
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
64+
var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeEncoder);
65+
var outputTensorMetaData = outputMetaData[outputNames[0]];
6466

6567
//TODO: Model Config, Channels
66-
var outputBuffer = new DenseTensor<float>(options.GetScaledDimension());
67-
using (var inputTensorValue = imageTensor.ToOrtValue())
68-
using (var outputTensorValue = outputBuffer.ToOrtValue())
68+
var outputDimension = options.GetScaledDimension();
69+
using (var inputTensorValue = imageTensor.ToOrtValue(outputTensorMetaData))
70+
using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension))
6971
{
7072
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
7173
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
7274
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
7375
using (var result = results.First())
7476
{
75-
var scaledSample = outputBuffer
76-
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
77+
var outputResult = outputTensorValue.ToDenseTensor();
78+
var scaledSample = outputResult
79+
.Add(scheduler.CreateRandomSample(outputDimension, options.InitialNoiseLevel))
7780
.MultiplyBy(model.ScaleFactor);
7881

7982
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
107107
var inputNames = _onnxModelService.GetInputNames(modelOptions, OnnxModelType.Unet);
108108
var outputNames = _onnxModelService.GetOutputNames(modelOptions, OnnxModelType.Unet);
109109
var inputMetaData = _onnxModelService.GetInputMetadata(modelOptions, OnnxModelType.Unet);
110+
var outputMetaData = _onnxModelService.GetOutputMetadata(modelOptions, OnnxModelType.Unet);
111+
var timestepMetaData = inputMetaData[inputNames[1]];
112+
var outputTensorMetaData = outputMetaData[outputNames[0]];
110113

111114
// Loop though the timesteps
112115
var step = 0;
@@ -120,17 +123,17 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
120123
var inputTensor = scheduler.ScaleInput(latents, timestep);
121124

122125
var outputChannels = 1;
123-
var outputBuffer = new DenseTensor<float>(schedulerOptions.GetScaledDimension(outputChannels));
124-
using (var outputTensorValue = outputBuffer.ToOrtValue())
125-
using (var inputTensorValue = inputTensor.ToOrtValue())
126-
using (var timestepOrtValue = CreateTimestepNamedOrtValue(inputMetaData, inputNames[1], timestep))
127-
using (var promptTensorValue = promptEmbeddings.ToOrtValue())
128-
using (var guidanceTensorValue = guidanceEmbeddings.ToOrtValue())
126+
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
127+
using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension))
128+
using (var inputTensorValue = inputTensor.ToOrtValue(outputTensorMetaData))
129+
using (var promptTensorValue = promptEmbeddings.ToOrtValue(outputTensorMetaData))
130+
using (var guidanceTensorValue = guidanceEmbeddings.ToOrtValue(outputTensorMetaData))
131+
using (var timestepTensorValue = CreateTimestepNamedOrtValue(timestepMetaData, timestep))
129132
{
130133
var inputs = new Dictionary<string, OrtValue>
131134
{
132135
{ inputNames[0], inputTensorValue },
133-
{ inputNames[1], timestepOrtValue },
136+
{ inputNames[1], timestepTensorValue },
134137
{ inputNames[2], promptTensorValue },
135138
{ inputNames[3], guidanceTensorValue }
136139
};
@@ -139,7 +142,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
139142
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputs, outputs);
140143
using (var result = results.First())
141144
{
142-
var noisePred = outputBuffer;
145+
var noisePred = outputTensorValue.ToDenseTensor();
143146

144147
// Scheduler Step
145148
var schedulerResult = scheduler.Step(noisePred, timestep, latents);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/ImageDiffuser.cs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.Extensions.Logging;
22
using Microsoft.ML.OnnxRuntime;
33
using Microsoft.ML.OnnxRuntime.Tensors;
4+
using OnnxStack.Core;
45
using OnnxStack.Core.Config;
56
using OnnxStack.Core.Services;
67
using OnnxStack.StableDiffusion.Common;
@@ -12,7 +13,6 @@
1213
using System.Collections.Generic;
1314
using System.Linq;
1415
using System.Threading.Tasks;
15-
using OnnxStack.Core;
1616

1717
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusion
1818
{
@@ -63,19 +63,22 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOpti
6363
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
6464
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder);
6565
var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.VaeEncoder);
66+
var outputMetaData = _onnxModelService.GetOutputMetadata(model, OnnxModelType.VaeEncoder);
67+
var outputTensorMetaData = outputMetaData[outputNames[0]];
6668

6769
//TODO: Model Config, Channels
68-
var outputBuffer = new DenseTensor<float>(options.GetScaledDimension());
69-
using (var inputTensorValue = imageTensor.ToOrtValue())
70-
using (var outputTensorValue = outputBuffer.ToOrtValue())
70+
var outputDimension = options.GetScaledDimension();
71+
using (var inputTensorValue = imageTensor.ToOrtValue(outputTensorMetaData))
72+
using (var outputTensorValue = outputTensorMetaData.CreateOutputBuffer(outputDimension))
7173
{
7274
var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
7375
var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
7476
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inputs, outputs);
7577
using (var result = results.First())
7678
{
77-
var scaledSample = outputBuffer
78-
.Add(scheduler.CreateRandomSample(outputBuffer.Dimensions, options.InitialNoiseLevel))
79+
var outputResult = outputTensorValue.ToDenseTensor();
80+
var scaledSample = outputResult
81+
.Add(scheduler.CreateRandomSample(outputDimension, options.InitialNoiseLevel))
7982
.MultiplyBy(model.ScaleFactor);
8083

8184
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);

0 commit comments

Comments
 (0)